mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert D25074763: [WIP] Update foreach APIs to use scalar lists
Test Plan: revert-hammer
Differential Revision:
D25074763 (cce84b5ca5)
Original commit changeset: 155e3d2073a2
fbshipit-source-id: ef0d153e2740b50bd4a95f7a57c370bb5da46355
This commit is contained in:
parent
d1bc1ab8ca
commit
443a431ac3
|
|
@ -26,7 +26,7 @@ std::vector<Tensor> foreach_tensor_##OP##_scalar_kernel_slow(TensorList tensors,
|
|||
}
|
||||
|
||||
#define FOREACH_BINARY_OP_SCALARLIST(OP) \
|
||||
void foreach_tensor_##OP##_scalarlist_kernel_slow_(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
|
||||
void foreach_tensor_##OP##_scalarlist_kernel_slow_(TensorList tensors, at::ArrayRef<double> scalars) { \
|
||||
check_foreach_api_restrictions(tensors, scalars); \
|
||||
\
|
||||
for (const auto i : c10::irange(tensors.size())) { \
|
||||
|
|
@ -34,7 +34,7 @@ void foreach_tensor_##OP##_scalarlist_kernel_slow_(TensorList tensors, at::Array
|
|||
} \
|
||||
} \
|
||||
\
|
||||
std::vector<Tensor> foreach_tensor_##OP##_scalarlist_kernel_slow(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
|
||||
std::vector<Tensor> foreach_tensor_##OP##_scalarlist_kernel_slow(TensorList tensors, at::ArrayRef<double> scalars) { \
|
||||
check_foreach_api_restrictions(tensors, scalars); \
|
||||
std::vector<Tensor> result; \
|
||||
result.reserve(tensors.size()); \
|
||||
|
|
@ -129,7 +129,7 @@ void foreach_tensor_##OP##_scalar_slow_(TensorList input, TensorList tensors1, T
|
|||
} \
|
||||
|
||||
#define FOREACH_POINTWISE_OP_SCALARLIST(OP) \
|
||||
std::vector<Tensor> foreach_tensor_##OP##_scalarlist_slow(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) { \
|
||||
std::vector<Tensor> foreach_tensor_##OP##_scalarlist_slow(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) { \
|
||||
check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \
|
||||
\
|
||||
std::vector<Tensor> result; \
|
||||
|
|
@ -140,7 +140,7 @@ std::vector<Tensor> foreach_tensor_##OP##_scalarlist_slow(TensorList input, Tens
|
|||
return result; \
|
||||
} \
|
||||
\
|
||||
void foreach_tensor_##OP##_scalarlist_slow_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) { \
|
||||
void foreach_tensor_##OP##_scalarlist_slow_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) { \
|
||||
check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \
|
||||
\
|
||||
for(const auto i : c10::irange(input.size())) { \
|
||||
|
|
|
|||
|
|
@ -4,11 +4,7 @@
|
|||
namespace at {
|
||||
namespace native {
|
||||
namespace {
|
||||
// Check foreach API restrictions
|
||||
// - Tensor lists must be non-empty.
|
||||
// - All tensors in all lists must have the same dtype.
|
||||
// - All TensorLists and ScalarLists must have the same number of elements.
|
||||
// - Corresponding tensors must have the same size.
|
||||
|
||||
void check_foreach_api_restrictions(TensorList tensors) {
|
||||
TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
|
||||
auto expected_dtype = tensors[0].dtype();
|
||||
|
|
@ -17,7 +13,7 @@ void check_foreach_api_restrictions(TensorList tensors) {
|
|||
}
|
||||
}
|
||||
|
||||
void check_foreach_api_restrictions(TensorList tensors, ArrayRef<Scalar> scalars) {
|
||||
void check_foreach_api_restrictions(TensorList tensors, ArrayRef<double> scalars) {
|
||||
check_foreach_api_restrictions(tensors);
|
||||
TORCH_CHECK(tensors.size() == scalars.size(), "Tensor list must have same number of elements as scalar list.");
|
||||
}
|
||||
|
|
@ -53,7 +49,7 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, Te
|
|||
}
|
||||
}
|
||||
|
||||
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<Scalar> scalars) {
|
||||
void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double> scalars) {
|
||||
check_foreach_api_restrictions(tensors1, tensors2, tensors3);
|
||||
TORCH_CHECK(tensors1.size() == scalars.size(), "Tensor list must have same number of elements as scalar list, got ", tensors1.size(), " and ", scalars.size());
|
||||
}
|
||||
|
|
@ -89,8 +85,21 @@ bool has_same_attributes(Device expected_device, TensorList tensors) {
|
|||
}
|
||||
|
||||
bool will_promote_tensor(const Tensor& tensor, Scalar scalar) {
|
||||
auto result_dtype = at::result_type(tensor, scalar);
|
||||
return result_dtype != tensor.scalar_type();
|
||||
// complex scalar + integral or boolean tensor will result in complex tensor
|
||||
if (scalar.isComplex() && at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// float scalar + integral or boolean tensor will result in float tensor
|
||||
if (scalar.isFloatingPoint() && at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// integral scalar + boolean tensor will result in integral tensor
|
||||
if (scalar.isIntegral(/*includeBool*/ false) && tensor.dtype() == at::kBool) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool can_use_fast_route(TensorList tensors) {
|
||||
|
|
@ -119,7 +128,7 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (will_promote_tensor(t, scalar)) {
|
||||
if (!will_promote_tensor(t, scalar)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -128,18 +137,8 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) {
|
|||
#endif
|
||||
}
|
||||
|
||||
bool can_use_fast_route(TensorList tensors, ArrayRef<Scalar> scalars) {
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
return false;
|
||||
#else
|
||||
for (int i = 0; i < tensors.size(); i++) {
|
||||
if (will_promote_tensor(tensors[i], scalars[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
#endif
|
||||
bool can_use_fast_route(TensorList tensors, ArrayRef<double> scalars) {
|
||||
return can_use_fast_route(tensors);
|
||||
}
|
||||
|
||||
bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
|
||||
|
|
@ -167,7 +166,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, Scalar scalar)
|
|||
return false;
|
||||
}
|
||||
|
||||
if (will_promote_tensor(tensors1[i], scalar)) {
|
||||
if (!will_promote_tensor(tensors1[i], scalar)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -201,7 +200,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList ten
|
|||
return false;
|
||||
}
|
||||
|
||||
if (will_promote_tensor(tensors1[i], scalar)) {
|
||||
if (!will_promote_tensor(tensors1[i], scalar)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -210,7 +209,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList ten
|
|||
#endif
|
||||
}
|
||||
|
||||
bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<Scalar> scalars) {
|
||||
bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double> scalars) {
|
||||
return can_use_fast_route(tensors1, tensors2, tensors3);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
namespace at { namespace native {
|
||||
|
||||
template<template<class> class Op>
|
||||
std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<Scalar> scalars) {
|
||||
std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<double> scalars) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists;
|
||||
std::vector<at::Tensor> vec_res;
|
||||
vec_res.reserve(tensors.size());
|
||||
|
|
@ -18,51 +18,52 @@ std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<Scalar> s
|
|||
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda", [&]() {
|
||||
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
|
||||
multi_tensor_apply<2, opmath_t>(tensor_lists,
|
||||
scalars,
|
||||
BinaryOpScalarListFunctor<scalar_t,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1>(),
|
||||
Op<opmath_t>());
|
||||
multi_tensor_apply<2>(tensor_lists,
|
||||
scalars,
|
||||
BinaryOpScalarListFunctor<scalar_t,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1>(),
|
||||
|
||||
Op<opmath_t>());
|
||||
});
|
||||
return tensor_lists[1];
|
||||
}
|
||||
|
||||
template<template<class> class Op>
|
||||
void foreach_binary_op_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
|
||||
void foreach_binary_op_(TensorList tensors, at::ArrayRef<double> scalars) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists;
|
||||
tensor_lists.emplace_back(tensors.vec());
|
||||
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda_", [&]() {
|
||||
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
|
||||
multi_tensor_apply<1, opmath_t>(tensor_lists,
|
||||
scalars,
|
||||
BinaryOpScalarListFunctor<scalar_t,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0>(),
|
||||
Op<opmath_t>());
|
||||
multi_tensor_apply<1>(tensor_lists,
|
||||
scalars,
|
||||
BinaryOpScalarListFunctor<scalar_t,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0>(),
|
||||
Op<opmath_t>());
|
||||
});
|
||||
}
|
||||
|
||||
#define FOREACH_BINARY_OP_SCALARLIST(NAME, OP) \
|
||||
void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
|
||||
check_foreach_api_restrictions(tensors, scalars); \
|
||||
if (!can_use_fast_route(tensors, scalars)) { \
|
||||
return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow_(tensors, scalars); \
|
||||
} \
|
||||
\
|
||||
foreach_binary_op_<OP>(tensors, scalars); \
|
||||
} \
|
||||
\
|
||||
std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
|
||||
check_foreach_api_restrictions(tensors, scalars); \
|
||||
if (!can_use_fast_route(tensors, scalars)) { \
|
||||
return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow(tensors, scalars); \
|
||||
} \
|
||||
\
|
||||
return foreach_binary_op<OP>(tensors, scalars); \
|
||||
#define FOREACH_BINARY_OP_SCALARLIST(NAME, OP) \
|
||||
void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<double> scalars) { \
|
||||
check_foreach_api_restrictions(tensors, scalars); \
|
||||
if (!can_use_fast_route(tensors, scalars)) { \
|
||||
return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow_(tensors, scalars); \
|
||||
} \
|
||||
\
|
||||
foreach_binary_op_<OP>(tensors, scalars); \
|
||||
} \
|
||||
\
|
||||
std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<double> scalars) { \
|
||||
check_foreach_api_restrictions(tensors, scalars); \
|
||||
if (!can_use_fast_route(tensors, scalars)) { \
|
||||
return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow(tensors, scalars); \
|
||||
} \
|
||||
\
|
||||
return foreach_binary_op<OP>(tensors, scalars); \
|
||||
}
|
||||
|
||||
FOREACH_BINARY_OP_SCALARLIST(add, std::plus);
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList ten
|
|||
}
|
||||
|
||||
template<template<class> class Op>
|
||||
void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) {
|
||||
void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists;
|
||||
tensor_lists.reserve(3);
|
||||
tensor_lists.emplace_back(input.vec());
|
||||
|
|
@ -62,18 +62,18 @@ void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList ten
|
|||
|
||||
AT_DISPATCH_ALL_TYPES_AND(kHalf, input[0].scalar_type(), "foreach_pointwise_op__cuda", [&]() {
|
||||
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
|
||||
multi_tensor_apply<3, opmath_t>(tensor_lists,
|
||||
scalars,
|
||||
PointwiseOpScalarListFunctor<scalar_t,
|
||||
/* depth */ 3,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 0>(),
|
||||
Op<opmath_t>());
|
||||
multi_tensor_apply<3>(tensor_lists,
|
||||
scalars,
|
||||
PointwiseOpScalarListFunctor<scalar_t,
|
||||
/* depth */ 3,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 0>(),
|
||||
Op<opmath_t>());
|
||||
});
|
||||
}
|
||||
|
||||
template<template<class> class Op>
|
||||
std::vector<Tensor> foreach_pointwise_op(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) {
|
||||
std::vector<Tensor> foreach_pointwise_op(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists;
|
||||
tensor_lists.reserve(4);
|
||||
std::vector<at::Tensor> vec_res;
|
||||
|
|
@ -89,13 +89,13 @@ std::vector<Tensor> foreach_pointwise_op(TensorList input, TensorList tensors1,
|
|||
|
||||
AT_DISPATCH_ALL_TYPES_AND(kHalf, input[0].scalar_type(), "foreach_pointwise_op_cuda", [&]() {
|
||||
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
|
||||
multi_tensor_apply<4, opmath_t>(tensor_lists,
|
||||
scalars,
|
||||
PointwiseOpScalarListFunctor<scalar_t,
|
||||
/* depth */ 4,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 3>(),
|
||||
Op<opmath_t>());
|
||||
multi_tensor_apply<4>(tensor_lists,
|
||||
scalars,
|
||||
PointwiseOpScalarListFunctor<scalar_t,
|
||||
/* depth */ 4,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 3>(),
|
||||
Op<opmath_t>());
|
||||
});
|
||||
|
||||
return tensor_lists[3];
|
||||
|
|
@ -124,7 +124,7 @@ void foreach_tensor_##NAME##_scalar_cuda_(TensorList input, TensorList tensors1,
|
|||
|
||||
|
||||
#define FOREACH_POINTWISE_OP_SCALARLIST(NAME, OP) \
|
||||
std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_cuda(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) { \
|
||||
std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_cuda(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) { \
|
||||
check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \
|
||||
\
|
||||
if (!can_use_fast_route(input, tensors1, tensors2, scalars)) { \
|
||||
|
|
@ -134,7 +134,7 @@ std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_cuda(TensorList input, Te
|
|||
return foreach_pointwise_op<OP>(input, tensors1, tensors2, scalars); \
|
||||
} \
|
||||
\
|
||||
void foreach_tensor_##NAME##_scalarlist_cuda_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) { \
|
||||
void foreach_tensor_##NAME##_scalarlist_cuda_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) { \
|
||||
check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \
|
||||
\
|
||||
if (!can_use_fast_route(input, tensors1, tensors2, scalars)) { \
|
||||
|
|
|
|||
|
|
@ -56,10 +56,10 @@ multi_tensor_apply_kernel(
|
|||
callable(kChunkSize, tensorListMeta, args...);
|
||||
}
|
||||
|
||||
template<int depth, typename scalar_T, typename T, typename... ArgTypes>
|
||||
template<int depth, typename T, typename... ArgTypes>
|
||||
void multi_tensor_apply(
|
||||
std::vector<std::vector<at::Tensor>>& tensor_lists,
|
||||
at::ArrayRef<Scalar> scalars,
|
||||
at::ArrayRef<double> scalars,
|
||||
T callable,
|
||||
ArgTypes... args) {
|
||||
TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth.");
|
||||
|
|
@ -71,7 +71,7 @@ void multi_tensor_apply(
|
|||
int loc_tensor_info = 0;
|
||||
for(size_t t = 0; t < n_tensors; t++) {
|
||||
|
||||
tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to<scalar_T>();
|
||||
tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t];
|
||||
|
||||
tensorListMeta.numel_for_tensor[loc_tensor_info] = tensor_lists[0][t].numel();
|
||||
for (int d = 0; d < depth; d++) {
|
||||
|
|
@ -115,6 +115,7 @@ void multi_tensor_apply(
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
template<int depth, typename T, typename... ArgTypes>
|
||||
void multi_tensor_apply(
|
||||
std::vector<std::vector<at::Tensor>>& tensor_lists,
|
||||
|
|
|
|||
|
|
@ -6713,49 +6713,49 @@
|
|||
CPU: foreach_tensor_div_list_kernel_slow_
|
||||
CUDA: foreach_tensor_div_list_kernel_cuda_
|
||||
|
||||
- func: _foreach_add.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
|
||||
- func: _foreach_add.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[]
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: foreach_tensor_add_scalarlist_kernel_slow
|
||||
CUDA: foreach_tensor_add_scalarlist_kernel_cuda
|
||||
|
||||
- func: _foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
|
||||
- func: _foreach_add_.ScalarList(Tensor(a!)[] self, float[] scalars) -> ()
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: foreach_tensor_add_scalarlist_kernel_slow_
|
||||
CUDA: foreach_tensor_add_scalarlist_kernel_cuda_
|
||||
|
||||
- func: _foreach_sub.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
|
||||
- func: _foreach_sub.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[]
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: foreach_tensor_sub_scalarlist_kernel_slow
|
||||
CUDA: foreach_tensor_sub_scalarlist_kernel_cuda
|
||||
|
||||
- func: _foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
|
||||
- func: _foreach_sub_.ScalarList(Tensor(a!)[] self, float[] scalars) -> ()
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: foreach_tensor_sub_scalarlist_kernel_slow_
|
||||
CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_
|
||||
|
||||
- func: _foreach_div.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
|
||||
- func: _foreach_div.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[]
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: foreach_tensor_div_scalarlist_kernel_slow
|
||||
CUDA: foreach_tensor_div_scalarlist_kernel_cuda
|
||||
|
||||
- func: _foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
|
||||
- func: _foreach_div_.ScalarList(Tensor(a!)[] self, float[] scalars) -> ()
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: foreach_tensor_div_scalarlist_kernel_slow_
|
||||
CUDA: foreach_tensor_div_scalarlist_kernel_cuda_
|
||||
|
||||
- func: _foreach_mul.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
|
||||
- func: _foreach_mul.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[]
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: foreach_tensor_mul_scalarlist_kernel_slow
|
||||
CUDA: foreach_tensor_mul_scalarlist_kernel_cuda
|
||||
|
||||
- func: _foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
|
||||
- func: _foreach_mul_.ScalarList(Tensor(a!)[] self, float[] scalars) -> ()
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: foreach_tensor_mul_scalarlist_kernel_slow_
|
||||
|
|
@ -7115,13 +7115,13 @@
|
|||
CPU: foreach_tensor_addcmul_scalar_slow_
|
||||
CUDA: foreach_tensor_addcmul_scalar_cuda_
|
||||
|
||||
- func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
|
||||
- func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, float[] scalars) -> ()
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: foreach_tensor_addcdiv_scalarlist_slow_
|
||||
CUDA: foreach_tensor_addcdiv_scalarlist_cuda_
|
||||
|
||||
- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
|
||||
- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, float[] scalars) -> ()
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: foreach_tensor_addcmul_scalarlist_slow_
|
||||
|
|
@ -7139,13 +7139,13 @@
|
|||
CPU: foreach_tensor_addcmul_scalar_slow
|
||||
CUDA: foreach_tensor_addcmul_scalar_cuda
|
||||
|
||||
- func: _foreach_addcdiv.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
|
||||
- func: _foreach_addcdiv.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, float[] scalars) -> Tensor[]
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: foreach_tensor_addcdiv_scalarlist_slow
|
||||
CUDA: foreach_tensor_addcdiv_scalarlist_cuda
|
||||
|
||||
- func: _foreach_addcmul.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
|
||||
- func: _foreach_addcmul.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, float[] scalars) -> Tensor[]
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: foreach_tensor_addcmul_scalarlist_slow
|
||||
|
|
|
|||
|
|
@ -59,19 +59,6 @@ allow_list = [
|
|||
("aten::fake_quantize_per_tensor_affine_backward", datetime.date(2021, 2, 20)),
|
||||
("aten::fake_quantize_per_channel_affine_backward", datetime.date(2021, 2, 20)),
|
||||
("aten::rowwise_prune", datetime.date(9999, 1, 1)),
|
||||
("aten::_foreach_mul_", datetime.date(2021, 3, 2)),
|
||||
("aten::_foreach_addcdiv_", datetime.date(2021, 3, 2)),
|
||||
("aten::_foreach_div", datetime.date(2021, 3, 2)),
|
||||
("aten::_foreach_addcmul_", datetime.date(2021, 3, 2)),
|
||||
("aten::_foreach_sub", datetime.date(2021, 3, 2)),
|
||||
("aten::_foreach_add", datetime.date(2021, 3, 2)),
|
||||
("aten::_foreach_sub_", datetime.date(2021, 3, 2)),
|
||||
("aten::_foreach_add_", datetime.date(2021, 3, 2)),
|
||||
("aten::_foreach_mul", datetime.date(2021, 3, 2)),
|
||||
("aten::_foreach_div_", datetime.date(2021, 3, 2)),
|
||||
("aten::_foreach_addcdiv", datetime.date(2021, 3, 2)),
|
||||
("aten::_foreach_addcmul", datetime.date(2021, 3, 2)),
|
||||
|
||||
]
|
||||
|
||||
def allow_listed(schema, allow_list):
|
||||
|
|
|
|||
|
|
@ -7,11 +7,25 @@ from torch._six import inf, nan
|
|||
N_values = [20] if not TEST_WITH_SLOW else [30, 300]
|
||||
|
||||
class TestForeach(TestCase):
|
||||
bin_ops = [
|
||||
(torch._foreach_add, torch._foreach_add_, torch.add),
|
||||
(torch._foreach_sub, torch._foreach_sub_, torch.sub),
|
||||
(torch._foreach_mul, torch._foreach_mul_, torch.mul),
|
||||
(torch._foreach_div, torch._foreach_div_, torch.div),
|
||||
foreach_bin_ops = [
|
||||
torch._foreach_add,
|
||||
torch._foreach_sub,
|
||||
torch._foreach_mul,
|
||||
torch._foreach_div,
|
||||
]
|
||||
|
||||
foreach_bin_ops_ = [
|
||||
torch._foreach_add_,
|
||||
torch._foreach_sub_,
|
||||
torch._foreach_mul_,
|
||||
torch._foreach_div_,
|
||||
]
|
||||
|
||||
torch_bin_ops = [
|
||||
torch.add,
|
||||
torch.sub,
|
||||
torch.mul,
|
||||
torch.div,
|
||||
]
|
||||
|
||||
unary_ops = [
|
||||
|
|
@ -128,7 +142,7 @@ class TestForeach(TestCase):
|
|||
op(tensors, tensors1, tensors2, [2 for _ in range(N)])
|
||||
|
||||
def _test_bin_op_list_alpha(self, device, dtype, foreach_op, foreach_op_, torch_op):
|
||||
for N in N_values:
|
||||
for N in [30, 300]:
|
||||
tensors1 = self._get_test_data(device, dtype, N)
|
||||
tensors2 = self._get_test_data(device, dtype, N)
|
||||
alpha = 2
|
||||
|
|
@ -338,7 +352,9 @@ class TestForeach(TestCase):
|
|||
@dtypes(*torch.testing.get_all_dtypes())
|
||||
def test_int_scalar(self, device, dtype):
|
||||
for N in N_values:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
|
||||
self.foreach_bin_ops_,
|
||||
self.torch_bin_ops):
|
||||
tensors = self._get_test_data(device, dtype, N)
|
||||
scalar = 3
|
||||
expected = [torch_bin_op(t, scalar) for t in tensors]
|
||||
|
|
@ -378,13 +394,15 @@ class TestForeach(TestCase):
|
|||
@dtypes(*torch.testing.get_all_dtypes())
|
||||
def test_int_scalarlist(self, device, dtype):
|
||||
for N in N_values:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
|
||||
self.foreach_bin_ops_,
|
||||
self.torch_bin_ops):
|
||||
tensors = self._get_test_data(device, dtype, N)
|
||||
scalars = [1 for _ in range(N)]
|
||||
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
|
||||
|
||||
# we don't support bool and complex types on CUDA for now
|
||||
if dtype in torch.testing.get_all_complex_dtypes() and self.device_type == 'cuda':
|
||||
if (dtype in torch.testing.get_all_complex_dtypes() or dtype == torch.bool) and self.device_type == 'cuda':
|
||||
with self.assertRaisesRegex(RuntimeError, "not implemented for"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
|
||||
|
|
@ -395,41 +413,36 @@ class TestForeach(TestCase):
|
|||
res = foreach_bin_op(tensors, scalars)
|
||||
|
||||
if dtype == torch.bool:
|
||||
self.assertEqual(res, [torch_bin_op(t, s) for t, s in zip(tensors, scalars)])
|
||||
self.assertEqual(res, [torch_bin_op(t.to(torch.float32), s) for t, s in zip(tensors, scalars)])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
|
||||
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
return
|
||||
|
||||
# test out of place
|
||||
if dtype in torch.testing.integral_types():
|
||||
if self.device_type == 'cpu':
|
||||
self.assertEqual(res, expected)
|
||||
self.assertEqual(res, [e.to(torch.float32) for e in expected])
|
||||
else:
|
||||
# TODO[type promotion]: Fix once type promotion is enabled.
|
||||
self.assertEqual(res, [e.to(dtype) for e in expected])
|
||||
else:
|
||||
self.assertEqual(res, expected)
|
||||
|
||||
# test in-place
|
||||
if dtype in torch.testing.floating_types() and self.device_type == 'cpu':
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
if dtype in torch.testing.integral_types() and self.device_type == 'cpu':
|
||||
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
return
|
||||
else:
|
||||
if foreach_bin_op_ == torch._foreach_div_ and \
|
||||
dtype in torch.testing.integral_types() and \
|
||||
self.device_type == 'cpu':
|
||||
with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
else:
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
self.assertEqual(res, tensors)
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
self.assertEqual(res, tensors)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@dtypes(*torch.testing.get_all_dtypes())
|
||||
def test_float_scalar(self, device, dtype):
|
||||
for N in N_values:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
|
||||
self.foreach_bin_ops_,
|
||||
self.torch_bin_ops):
|
||||
tensors = self._get_test_data(device, dtype, N)
|
||||
scalar = 3.3
|
||||
|
||||
|
|
@ -471,38 +484,17 @@ class TestForeach(TestCase):
|
|||
@dtypes(*torch.testing.get_all_dtypes())
|
||||
def test_float_scalarlist(self, device, dtype):
|
||||
for N in N_values:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
|
||||
self.foreach_bin_ops_,
|
||||
self.torch_bin_ops):
|
||||
tensors = self._get_test_data(device, dtype, N)
|
||||
scalars = [1.1 for _ in range(N)]
|
||||
|
||||
# Bool case
|
||||
if dtype == torch.bool:
|
||||
if foreach_bin_op == torch._foreach_sub:
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
|
||||
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
|
||||
[t.sub_(scalar) for t, scalar in zip(tensors, scalars)]
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
continue
|
||||
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
|
||||
self.assertEqual(res, expected)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
continue
|
||||
|
||||
# If incoming dtype is float16 or bfloat16, runs in float32 and casts output back to dtype.
|
||||
control_dtype = torch.float32 if (self.device_type == 'cuda' and
|
||||
(dtype is torch.float16 or dtype is torch.bfloat16)) else dtype
|
||||
expected = [torch_bin_op(t.to(dtype=control_dtype), s) for t, s in zip(tensors, scalars)]
|
||||
expected = [torch_bin_op(t.to(dtype=control_dtype),
|
||||
s) for t, s in zip(tensors, scalars)]
|
||||
if (dtype is torch.float16 or dtype is torch.bfloat16):
|
||||
expected = [e.to(dtype=dtype) for e in expected]
|
||||
|
||||
|
|
@ -517,11 +509,21 @@ class TestForeach(TestCase):
|
|||
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
|
||||
if dtype in torch.testing.integral_types() and self.device_type == 'cuda':
|
||||
self.assertEqual(res, expected)
|
||||
if dtype == torch.bool:
|
||||
# see TODO[Fix scalar list]
|
||||
self.assertEqual(res, [torch_bin_op(t.to(torch.float32), s) for t, s in zip(tensors, scalars)])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
continue
|
||||
return
|
||||
|
||||
if dtype in torch.testing.integral_types() and self.device_type == 'cuda':
|
||||
# see TODO[Fix scalar list]
|
||||
self.assertEqual(res, [e.to(dtype) for e in expected])
|
||||
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
self.assertEqual(tensors, res)
|
||||
return
|
||||
else:
|
||||
if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM:
|
||||
self.assertEqual(res, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0])
|
||||
|
|
@ -531,7 +533,7 @@ class TestForeach(TestCase):
|
|||
if dtype in torch.testing.integral_types() and self.device_type == "cpu":
|
||||
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
continue
|
||||
return
|
||||
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM:
|
||||
|
|
@ -543,29 +545,34 @@ class TestForeach(TestCase):
|
|||
@dtypes(*torch.testing.get_all_dtypes())
|
||||
def test_complex_scalar(self, device, dtype):
|
||||
for N in N_values:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
|
||||
self.foreach_bin_ops_,
|
||||
self.torch_bin_ops):
|
||||
tensors = self._get_test_data(device, dtype, N)
|
||||
scalar = 3 + 5j
|
||||
expected = [torch_bin_op(t, scalar) for t in tensors]
|
||||
|
||||
# Bool case
|
||||
if dtype == torch.bool:
|
||||
if foreach_bin_op == torch._foreach_sub:
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator,"):
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
|
||||
foreach_bin_op_(tensors, scalar)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator,"):
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
|
||||
foreach_bin_op(tensors, scalar)
|
||||
continue
|
||||
return
|
||||
|
||||
if dtype in torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=True) and \
|
||||
self.device_type == 'cuda':
|
||||
with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"):
|
||||
foreach_bin_op_(tensors, scalar)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"):
|
||||
foreach_bin_op(tensors, scalar)
|
||||
return
|
||||
|
||||
res = foreach_bin_op(tensors, scalar)
|
||||
expected = [torch_bin_op(t, scalar) for t in tensors]
|
||||
self.assertEqual(res, expected)
|
||||
|
||||
if dtype in torch.testing.get_all_fp_dtypes() and self.device_type == 'cuda':
|
||||
with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
|
||||
foreach_bin_op_(tensors, scalar)
|
||||
continue
|
||||
|
||||
if dtype not in [torch.complex64, torch.complex128]:
|
||||
with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
|
||||
foreach_bin_op_(tensors, scalar)
|
||||
|
|
@ -573,48 +580,38 @@ class TestForeach(TestCase):
|
|||
foreach_bin_op_(tensors, scalar)
|
||||
self.assertEqual(res, tensors)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@dtypes(*torch.testing.get_all_dtypes())
|
||||
def test_complex_scalarlist(self, device, dtype):
|
||||
for N in N_values:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
|
||||
self.foreach_bin_ops_,
|
||||
self.torch_bin_ops):
|
||||
tensors = self._get_test_data(device, dtype, N)
|
||||
scalars = [3 + 5j for _ in range(N)]
|
||||
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
|
||||
|
||||
# Bool case
|
||||
if dtype == torch.bool:
|
||||
if foreach_bin_op == torch._foreach_sub:
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
|
||||
foreach_bin_op_(tensors, scalar)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator"):
|
||||
foreach_bin_op(tensors, scalars)
|
||||
continue
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
|
||||
foreach_bin_op(tensors, scalar)
|
||||
return
|
||||
|
||||
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
|
||||
if dtype in [torch.complex64, torch.complex128] and self.device_type == "cuda":
|
||||
with self.assertRaisesRegex(RuntimeError, "not implemented for"):
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "not implemented for"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
continue
|
||||
else:
|
||||
with self.assertRaisesRegex(TypeError, "argument 'scalars' must be tuple of floats"):
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
self.assertEqual(res, expected)
|
||||
|
||||
if dtype not in [torch.complex64, torch.complex128]:
|
||||
with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
else:
|
||||
with self.assertRaisesRegex(TypeError, "argument 'scalars' must be tuple of floats"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
self.assertEqual(res, tensors)
|
||||
|
||||
@skipCUDAIfRocm
|
||||
@dtypes(*torch.testing.get_all_dtypes())
|
||||
def test_bool_scalar(self, device, dtype):
|
||||
for N in N_values:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
|
||||
self.foreach_bin_ops_,
|
||||
self.torch_bin_ops):
|
||||
tensors = self._get_test_data(device, dtype, N)
|
||||
scalar = True
|
||||
|
||||
|
|
@ -663,7 +660,9 @@ class TestForeach(TestCase):
|
|||
@dtypes(*torch.testing.get_all_dtypes())
|
||||
def test_bool_scalarlist(self, device, dtype):
|
||||
for N in N_values:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
|
||||
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
|
||||
self.foreach_bin_ops_,
|
||||
self.torch_bin_ops):
|
||||
tensors = self._get_test_data(device, dtype, N)
|
||||
scalars = [True for _ in range(N)]
|
||||
|
||||
|
|
@ -677,24 +676,20 @@ class TestForeach(TestCase):
|
|||
return
|
||||
else:
|
||||
if foreach_bin_op == torch._foreach_sub:
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool tensors"):
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool tensors"):
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
|
||||
foreach_bin_op(tensors, scalars)
|
||||
else:
|
||||
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
self.assertEqual(res, expected)
|
||||
|
||||
if foreach_bin_op_ == torch._foreach_div_:
|
||||
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
else:
|
||||
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
self.assertEqual(res, tensors)
|
||||
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
for r in res:
|
||||
self.assertTrue(r.dtype == torch.float32)
|
||||
else:
|
||||
# we dont support complex types on CUDA for now
|
||||
# we dont support bool and complex types on CUDA for now
|
||||
if (dtype in torch.testing.get_all_complex_dtypes()) and self.device_type == 'cuda':
|
||||
with self.assertRaisesRegex(RuntimeError, "not implemented for"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
|
|
@ -703,44 +698,57 @@ class TestForeach(TestCase):
|
|||
foreach_bin_op(tensors, scalars)
|
||||
return
|
||||
|
||||
if foreach_bin_op == torch._foreach_sub and self.device_type == "cpu":
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
if foreach_bin_op == torch._foreach_sub:
|
||||
if self.device_type == "cpu":
|
||||
# see TODO[Fix scalar list]
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
if dtype in torch.testing.integral_types():
|
||||
self.assertEqual(res, [r.to(torch.float32) for r in [torch_bin_op(t, 1) for t in tensors]])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
|
||||
foreach_bin_op(tensors, scalars)
|
||||
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the "):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
else:
|
||||
self.assertEqual(res, [torch_bin_op(t, 1) for t in tensors])
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
self.assertEqual(res, tensors)
|
||||
else:
|
||||
# see TODO[Fix scalar list]
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
if dtype in torch.testing.integral_types():
|
||||
self.assertEqual(res, [r.to(dtype) for r in [torch_bin_op(t, 1) for t in tensors]])
|
||||
else:
|
||||
self.assertEqual(res, [torch_bin_op(t, 1) for t in tensors])
|
||||
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
self.assertEqual(res, tensors)
|
||||
else:
|
||||
if self.device_type == "cpu":
|
||||
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
|
||||
self.assertEqual(res, expected)
|
||||
# see TODO[Fix scalar list]
|
||||
if dtype in torch.testing.integral_types():
|
||||
self.assertEqual(res, [e.to(torch.float32) for e in expected])
|
||||
else:
|
||||
self.assertEqual(res, expected)
|
||||
|
||||
if dtype in torch.testing.integral_types() and foreach_bin_op_ == torch._foreach_div_:
|
||||
if dtype in torch.testing.integral_types():
|
||||
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired "):
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
else:
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
self.assertEqual(tensors, expected)
|
||||
else:
|
||||
if foreach_bin_op == torch._foreach_sub:
|
||||
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bo"):
|
||||
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
|
||||
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
self.assertEqual(res, tensors)
|
||||
if dtype in torch.testing.integral_types():
|
||||
self.assertEqual(res, [e.to(dtype) for e in expected])
|
||||
else:
|
||||
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
|
||||
res = foreach_bin_op(tensors, scalars)
|
||||
self.assertEqual(res, expected)
|
||||
|
||||
if dtype in torch.testing.integral_types():
|
||||
self.assertEqual(res, [e.to(dtype) for e in expected])
|
||||
else:
|
||||
self.assertEqual(res, expected)
|
||||
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
self.assertEqual(res, tensors)
|
||||
foreach_bin_op_(tensors, scalars)
|
||||
self.assertEqual(res, tensors)
|
||||
|
||||
@dtypes(*torch.testing.get_all_dtypes())
|
||||
def test_add_with_different_size_tensors(self, device, dtype):
|
||||
|
|
@ -783,7 +791,7 @@ class TestForeach(TestCase):
|
|||
# Ops with list
|
||||
#
|
||||
def test_bin_op_list_error_cases(self, device):
|
||||
for bin_op, bin_op_, _ in self.bin_ops:
|
||||
for bin_op, bin_op_ in zip(self.foreach_bin_ops, self.foreach_bin_ops_):
|
||||
tensors1 = []
|
||||
tensors2 = []
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user