Revert D25074763: [WIP] Update foreach APIs to use scalar lists

Test Plan: revert-hammer

Differential Revision: D25074763 (cce84b5ca5)

Original commit changeset: 155e3d2073a2

fbshipit-source-id: ef0d153e2740b50bd4a95f7a57c370bb5da46355
Authored by Natalia Gimelshein on 2021-02-03 16:59:30 -08:00, committed by Facebook GitHub Bot
parent d1bc1ab8ca
commit 443a431ac3
8 changed files with 230 additions and 234 deletions
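
The revert below switches the ScalarList overloads of the _foreach_* ops from at::ArrayRef<Scalar> (Scalar[] in the schema) back to at::ArrayRef<double> (float[]). A minimal usage sketch of the affected Python surface (illustrative, not part of the commit):

    import torch

    tensors = [torch.ones(3) for _ in range(4)]
    scalars = [1.0, 2.0, 3.0, 4.0]               # one float per tensor

    out = torch._foreach_add(tensors, scalars)   # out[i] == tensors[i] + scalars[i]
    torch._foreach_add_(tensors, scalars)        # in-place variant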


@@ -26,7 +26,7 @@ std::vector<Tensor> foreach_tensor_##OP##_scalar_kernel_slow(TensorList tensors,
 }
 #define FOREACH_BINARY_OP_SCALARLIST(OP) \
-void foreach_tensor_##OP##_scalarlist_kernel_slow_(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
+void foreach_tensor_##OP##_scalarlist_kernel_slow_(TensorList tensors, at::ArrayRef<double> scalars) { \
   check_foreach_api_restrictions(tensors, scalars); \
 \
   for (const auto i : c10::irange(tensors.size())) { \
@@ -34,7 +34,7 @@ void foreach_tensor_##OP##_scalarlist_kernel_slow_(TensorList tensors, at::Array
   } \
 } \
 \
-std::vector<Tensor> foreach_tensor_##OP##_scalarlist_kernel_slow(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
+std::vector<Tensor> foreach_tensor_##OP##_scalarlist_kernel_slow(TensorList tensors, at::ArrayRef<double> scalars) { \
   check_foreach_api_restrictions(tensors, scalars); \
   std::vector<Tensor> result; \
   result.reserve(tensors.size()); \
@@ -129,7 +129,7 @@ void foreach_tensor_##OP##_scalar_slow_(TensorList input, TensorList tensors1, T
 } \
 #define FOREACH_POINTWISE_OP_SCALARLIST(OP) \
-std::vector<Tensor> foreach_tensor_##OP##_scalarlist_slow(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) { \
+std::vector<Tensor> foreach_tensor_##OP##_scalarlist_slow(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) { \
   check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \
 \
   std::vector<Tensor> result; \
@@ -140,7 +140,7 @@ std::vector<Tensor> foreach_tensor_##OP##_scalarlist_slow(TensorList input, Tens
   return result; \
 } \
 \
-void foreach_tensor_##OP##_scalarlist_slow_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) { \
+void foreach_tensor_##OP##_scalarlist_slow_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) { \
   check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \
 \
   for(const auto i : c10::irange(input.size())) { \
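
For reference, the slow path above is just an explicit loop over the paired lists; its out-of-place semantics amount to the following sketch (not the actual implementation):

    import torch

    def foreach_add_slow(tensors, scalars):
        # check_foreach_api_restrictions: non-empty, equal-length lists
        assert len(tensors) > 0 and len(tensors) == len(scalars)
        return [t + s for t, s in zip(tensors, scalars)]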


@@ -4,11 +4,7 @@
 namespace at {
 namespace native {
 namespace {
-// Check foreach API restrictions
-// - Tensor lists must be non-empty.
-// - All tensors in all lists must have the same dtype.
-// - All TensorLists and ScalarLists must have the same number of elements.
-// - Corresponding tensors must have the same size.
 void check_foreach_api_restrictions(TensorList tensors) {
   TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
   auto expected_dtype = tensors[0].dtype();
@@ -17,7 +13,7 @@ void check_foreach_api_restrictions(TensorList tensors) {
   }
 }
-void check_foreach_api_restrictions(TensorList tensors, ArrayRef<Scalar> scalars) {
+void check_foreach_api_restrictions(TensorList tensors, ArrayRef<double> scalars) {
   check_foreach_api_restrictions(tensors);
   TORCH_CHECK(tensors.size() == scalars.size(), "Tensor list must have same number of elements as scalar list.");
 }
@@ -53,7 +49,7 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, Te
   }
 }
-void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<Scalar> scalars) {
+void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double> scalars) {
   check_foreach_api_restrictions(tensors1, tensors2, tensors3);
   TORCH_CHECK(tensors1.size() == scalars.size(), "Tensor list must have same number of elements as scalar list, got ", tensors1.size(), " and ", scalars.size());
 }
@@ -89,8 +85,21 @@ bool has_same_attributes(Device expected_device, TensorList tensors) {
 }
 bool will_promote_tensor(const Tensor& tensor, Scalar scalar) {
   auto result_dtype = at::result_type(tensor, scalar);
-  return result_dtype != tensor.scalar_type();
+  // complex scalar + integral or boolean tensor will result in complex tensor
+  if (scalar.isComplex() && at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) {
+    return false;
+  }
+  // float scalar + integral or boolean tensor will result in float tensor
+  if (scalar.isFloatingPoint() && at::isIntegralType(tensor.scalar_type(), /*includeBool*/ true)) {
+    return false;
+  }
+  // integral scalar + boolean tensor will result in integral tensor
+  if (scalar.isIntegral(/*includeBool*/ false) && tensor.dtype() == at::kBool) {
+    return false;
+  }
+  return true;
 }
 bool can_use_fast_route(TensorList tensors) {
@@ -119,7 +128,7 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) {
       return false;
     }
-    if (will_promote_tensor(t, scalar)) {
+    if (!will_promote_tensor(t, scalar)) {
       return false;
     }
   }
@@ -128,18 +137,8 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) {
 #endif
 }
-bool can_use_fast_route(TensorList tensors, ArrayRef<Scalar> scalars) {
-#ifdef __HIP_PLATFORM_HCC__
-  return false;
-#else
-  for (int i = 0; i < tensors.size(); i++) {
-    if (will_promote_tensor(tensors[i], scalars[i])) {
-      return false;
-    }
-  }
-  return true;
-#endif
+bool can_use_fast_route(TensorList tensors, ArrayRef<double> scalars) {
+  return can_use_fast_route(tensors);
 }
 bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
@@ -167,7 +166,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, Scalar scalar)
       return false;
     }
-    if (will_promote_tensor(tensors1[i], scalar)) {
+    if (!will_promote_tensor(tensors1[i], scalar)) {
      return false;
    }
  }
@@ -201,7 +200,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList ten
       return false;
     }
-    if (will_promote_tensor(tensors1[i], scalar)) {
+    if (!will_promote_tensor(tensors1[i], scalar)) {
       return false;
     }
   }
@@ -210,7 +209,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList ten
 #endif
 }
-bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<Scalar> scalars) {
+bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double> scalars) {
   return can_use_fast_route(tensors1, tensors2, tensors3);
 }
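
Note the flipped predicate: despite keeping the old name, will_promote_tensor now returns true when the tensor/scalar pair is safe for the fused fast route, which is why every caller negates it. A rough Python rendering of the three early-outs (a sketch, assuming scalar is a plain Python number):

    import torch

    def safe_for_fast_path(tensor, scalar):
        integral_or_bool = not (tensor.dtype.is_floating_point or tensor.dtype.is_complex)
        if isinstance(scalar, complex) and integral_or_bool:
            return False  # would promote to a complex tensor
        if isinstance(scalar, float) and integral_or_bool:
            return False  # would promote to a float tensor
        if isinstance(scalar, int) and not isinstance(scalar, bool) and tensor.dtype == torch.bool:
            return False  # would promote to an integral tensor
        return True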


@@ -5,7 +5,7 @@
 namespace at { namespace native {
 template<template<class> class Op>
-std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<Scalar> scalars) {
+std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<double> scalars) {
   std::vector<std::vector<at::Tensor>> tensor_lists;
   std::vector<at::Tensor> vec_res;
   vec_res.reserve(tensors.size());
@@ -18,51 +18,52 @@ std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<Scalar> s
   AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda", [&]() {
     using opmath_t = get_opmath_t<scalar_t>::opmath_t;
-    multi_tensor_apply<2, opmath_t>(tensor_lists,
-                                    scalars,
-                                    BinaryOpScalarListFunctor<scalar_t,
-                                                              /* depth */ 2,
-                                                              /* r_args_depth */ 1,
-                                                              /* res_arg_index */ 1>(),
-                                    Op<opmath_t>());
+    multi_tensor_apply<2>(tensor_lists,
+                          scalars,
+                          BinaryOpScalarListFunctor<scalar_t,
+                                                    /* depth */ 2,
+                                                    /* r_args_depth */ 1,
+                                                    /* res_arg_index */ 1>(),
+                          Op<opmath_t>());
   });
   return tensor_lists[1];
 }
 template<template<class> class Op>
-void foreach_binary_op_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
+void foreach_binary_op_(TensorList tensors, at::ArrayRef<double> scalars) {
   std::vector<std::vector<at::Tensor>> tensor_lists;
   tensor_lists.emplace_back(tensors.vec());
   AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda_", [&]() {
     using opmath_t = get_opmath_t<scalar_t>::opmath_t;
-    multi_tensor_apply<1, opmath_t>(tensor_lists,
-                                    scalars,
-                                    BinaryOpScalarListFunctor<scalar_t,
-                                                              /* depth */ 1,
-                                                              /* r_args_depth */ 1,
-                                                              /* res_arg_index */ 0>(),
-                                    Op<opmath_t>());
+    multi_tensor_apply<1>(tensor_lists,
+                          scalars,
+                          BinaryOpScalarListFunctor<scalar_t,
+                                                    /* depth */ 1,
+                                                    /* r_args_depth */ 1,
+                                                    /* res_arg_index */ 0>(),
+                          Op<opmath_t>());
   });
 }
-#define FOREACH_BINARY_OP_SCALARLIST(NAME, OP) \
-void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
-  check_foreach_api_restrictions(tensors, scalars); \
-  if (!can_use_fast_route(tensors, scalars)) { \
-    return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow_(tensors, scalars); \
-  } \
- \
-  foreach_binary_op_<OP>(tensors, scalars); \
-} \
- \
-std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<Scalar> scalars) { \
-  check_foreach_api_restrictions(tensors, scalars); \
-  if (!can_use_fast_route(tensors, scalars)) { \
-    return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow(tensors, scalars); \
-  } \
- \
-  return foreach_binary_op<OP>(tensors, scalars); \
+#define FOREACH_BINARY_OP_SCALARLIST(NAME, OP) \
+void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<double> scalars) { \
+  check_foreach_api_restrictions(tensors, scalars); \
+  if (!can_use_fast_route(tensors, scalars)) { \
+    return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow_(tensors, scalars); \
+  } \
+ \
+  foreach_binary_op_<OP>(tensors, scalars); \
+} \
+ \
+std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<double> scalars) { \
+  check_foreach_api_restrictions(tensors, scalars); \
+  if (!can_use_fast_route(tensors, scalars)) { \
+    return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow(tensors, scalars); \
+  } \
+ \
+  return foreach_binary_op<OP>(tensors, scalars); \
 }
 FOREACH_BINARY_OP_SCALARLIST(add, std::plus);
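
The fused route generated here must agree with the slow path it replaces. A quick equivalence check (illustrative only; assumes a CUDA build, since the fast route is the CUDA kernel):

    import torch

    tensors = [torch.randn(5, device="cuda") for _ in range(3)]
    scalars = [0.5, 1.5, 2.5]
    fused = torch._foreach_mul(tensors, scalars)           # may take the fast route
    reference = [t * s for t, s in zip(tensors, scalars)]  # slow-path semantics
    assert all(torch.allclose(a, b) for a, b in zip(fused, reference))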


@@ -53,7 +53,7 @@ void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList ten
 }
 template<template<class> class Op>
-void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) {
+void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) {
   std::vector<std::vector<at::Tensor>> tensor_lists;
   tensor_lists.reserve(3);
   tensor_lists.emplace_back(input.vec());
@@ -62,18 +62,18 @@ void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList ten
   AT_DISPATCH_ALL_TYPES_AND(kHalf, input[0].scalar_type(), "foreach_pointwise_op__cuda", [&]() {
     using opmath_t = get_opmath_t<scalar_t>::opmath_t;
-    multi_tensor_apply<3, opmath_t>(tensor_lists,
-                                    scalars,
-                                    PointwiseOpScalarListFunctor<scalar_t,
-                                                                 /* depth */ 3,
-                                                                 /* r_args_depth */ 3,
-                                                                 /* res_arg_index */ 0>(),
-                                    Op<opmath_t>());
+    multi_tensor_apply<3>(tensor_lists,
+                          scalars,
+                          PointwiseOpScalarListFunctor<scalar_t,
+                                                       /* depth */ 3,
+                                                       /* r_args_depth */ 3,
+                                                       /* res_arg_index */ 0>(),
+                          Op<opmath_t>());
   });
 }
 template<template<class> class Op>
-std::vector<Tensor> foreach_pointwise_op(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) {
+std::vector<Tensor> foreach_pointwise_op(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) {
   std::vector<std::vector<at::Tensor>> tensor_lists;
   tensor_lists.reserve(4);
   std::vector<at::Tensor> vec_res;
@@ -89,13 +89,13 @@ std::vector<Tensor> foreach_pointwise_op(TensorList input, TensorList tensors1,
   AT_DISPATCH_ALL_TYPES_AND(kHalf, input[0].scalar_type(), "foreach_pointwise_op_cuda", [&]() {
     using opmath_t = get_opmath_t<scalar_t>::opmath_t;
-    multi_tensor_apply<4, opmath_t>(tensor_lists,
-                                    scalars,
-                                    PointwiseOpScalarListFunctor<scalar_t,
-                                                                 /* depth */ 4,
-                                                                 /* r_args_depth */ 3,
-                                                                 /* res_arg_index */ 3>(),
-                                    Op<opmath_t>());
+    multi_tensor_apply<4>(tensor_lists,
+                          scalars,
+                          PointwiseOpScalarListFunctor<scalar_t,
+                                                       /* depth */ 4,
+                                                       /* r_args_depth */ 3,
+                                                       /* res_arg_index */ 3>(),
+                          Op<opmath_t>());
   });
   return tensor_lists[3];
@@ -124,7 +124,7 @@ void foreach_tensor_##NAME##_scalar_cuda_(TensorList input, TensorList tensors1,
 #define FOREACH_POINTWISE_OP_SCALARLIST(NAME, OP) \
-std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_cuda(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) { \
+std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_cuda(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) { \
   check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \
 \
   if (!can_use_fast_route(input, tensors1, tensors2, scalars)) { \
@@ -134,7 +134,7 @@ std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_cuda(TensorList input, Te
   return foreach_pointwise_op<OP>(input, tensors1, tensors2, scalars); \
 } \
 \
-void foreach_tensor_##NAME##_scalarlist_cuda_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<Scalar> scalars) { \
+void foreach_tensor_##NAME##_scalarlist_cuda_(TensorList input, TensorList tensors1, TensorList tensors2, at::ArrayRef<double> scalars) { \
   check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \
 \
   if (!can_use_fast_route(input, tensors1, tensors2, scalars)) { \
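
For the pointwise ops, the scalar list supplies the value argument, one entry per tensor triple; the reference semantics (a sketch of what the slow path computes, not the committed code) are:

    import torch

    # torch._foreach_addcmul(input, tensor1, tensor2, scalars) is elementwise:
    #   out[i] = input[i] + scalars[i] * tensor1[i] * tensor2[i]
    inputs  = [torch.ones(2) for _ in range(3)]
    t1s     = [torch.full((2,), 2.0) for _ in range(3)]
    t2s     = [torch.full((2,), 3.0) for _ in range(3)]
    scalars = [0.1, 0.2, 0.3]
    ref = [torch.addcmul(inp, a, b, value=s)
           for inp, a, b, s in zip(inputs, t1s, t2s, scalars)]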


@@ -56,10 +56,10 @@ multi_tensor_apply_kernel(
   callable(kChunkSize, tensorListMeta, args...);
 }
-template<int depth, typename scalar_T, typename T, typename... ArgTypes>
+template<int depth, typename T, typename... ArgTypes>
 void multi_tensor_apply(
   std::vector<std::vector<at::Tensor>>& tensor_lists,
-  at::ArrayRef<Scalar> scalars,
+  at::ArrayRef<double> scalars,
   T callable,
   ArgTypes... args) {
   TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth.");
@@ -71,7 +71,7 @@ void multi_tensor_apply(
   int loc_tensor_info = 0;
   for(size_t t = 0; t < n_tensors; t++) {
-    tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to<scalar_T>();
+    tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t];
     tensorListMeta.numel_for_tensor[loc_tensor_info] = tensor_lists[0][t].numel();
     for (int d = 0; d < depth; d++) {
@@ -115,6 +115,7 @@ void multi_tensor_apply(
   }
 }
+
 template<int depth, typename T, typename... ArgTypes>
 void multi_tensor_apply(
   std::vector<std::vector<at::Tensor>>& tensor_lists,


@@ -6713,49 +6713,49 @@
     CPU: foreach_tensor_div_list_kernel_slow_
     CUDA: foreach_tensor_div_list_kernel_cuda_
-- func: _foreach_add.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
+- func: _foreach_add.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[]
   variants: function
   dispatch:
     CPU: foreach_tensor_add_scalarlist_kernel_slow
     CUDA: foreach_tensor_add_scalarlist_kernel_cuda
-- func: _foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+- func: _foreach_add_.ScalarList(Tensor(a!)[] self, float[] scalars) -> ()
   variants: function
   dispatch:
     CPU: foreach_tensor_add_scalarlist_kernel_slow_
     CUDA: foreach_tensor_add_scalarlist_kernel_cuda_
-- func: _foreach_sub.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
+- func: _foreach_sub.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[]
   variants: function
   dispatch:
     CPU: foreach_tensor_sub_scalarlist_kernel_slow
     CUDA: foreach_tensor_sub_scalarlist_kernel_cuda
-- func: _foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+- func: _foreach_sub_.ScalarList(Tensor(a!)[] self, float[] scalars) -> ()
   variants: function
   dispatch:
     CPU: foreach_tensor_sub_scalarlist_kernel_slow_
     CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_
-- func: _foreach_div.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
+- func: _foreach_div.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[]
   variants: function
   dispatch:
     CPU: foreach_tensor_div_scalarlist_kernel_slow
     CUDA: foreach_tensor_div_scalarlist_kernel_cuda
-- func: _foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+- func: _foreach_div_.ScalarList(Tensor(a!)[] self, float[] scalars) -> ()
   variants: function
   dispatch:
     CPU: foreach_tensor_div_scalarlist_kernel_slow_
     CUDA: foreach_tensor_div_scalarlist_kernel_cuda_
-- func: _foreach_mul.ScalarList(Tensor[] tensors, Scalar[] scalars) -> Tensor[]
+- func: _foreach_mul.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[]
   variants: function
   dispatch:
     CPU: foreach_tensor_mul_scalarlist_kernel_slow
     CUDA: foreach_tensor_mul_scalarlist_kernel_cuda
-- func: _foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+- func: _foreach_mul_.ScalarList(Tensor(a!)[] self, float[] scalars) -> ()
   variants: function
   dispatch:
     CPU: foreach_tensor_mul_scalarlist_kernel_slow_
@@ -7115,13 +7115,13 @@
     CPU: foreach_tensor_addcmul_scalar_slow_
     CUDA: foreach_tensor_addcmul_scalar_cuda_
-- func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
+- func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, float[] scalars) -> ()
   variants: function
   dispatch:
     CPU: foreach_tensor_addcdiv_scalarlist_slow_
     CUDA: foreach_tensor_addcdiv_scalarlist_cuda_
-- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
+- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, float[] scalars) -> ()
   variants: function
   dispatch:
     CPU: foreach_tensor_addcmul_scalarlist_slow_
@@ -7139,13 +7139,13 @@
     CPU: foreach_tensor_addcmul_scalar_slow
     CUDA: foreach_tensor_addcmul_scalar_cuda
-- func: _foreach_addcdiv.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
+- func: _foreach_addcdiv.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, float[] scalars) -> Tensor[]
   variants: function
   dispatch:
     CPU: foreach_tensor_addcdiv_scalarlist_slow
     CUDA: foreach_tensor_addcdiv_scalarlist_cuda
-- func: _foreach_addcmul.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
+- func: _foreach_addcmul.ScalarList(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, float[] scalars) -> Tensor[]
   variants: function
   dispatch:
     CPU: foreach_tensor_addcmul_scalarlist_slow
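
With the schema back to float[], the Python argument parser once again accepts only real-valued scalar lists; complex entries are rejected up front, which is what the updated tests below assert:

    import torch

    tensors = [torch.randn(2) for _ in range(3)]
    torch._foreach_mul(tensors, [1.0, 2.0, 3.0])   # ok: float[] scalars
    # torch._foreach_mul(tensors, [3 + 5j] * 3) would raise
    # TypeError: argument 'scalars' must be tuple of floats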


@@ -59,19 +59,6 @@ allow_list = [
     ("aten::fake_quantize_per_tensor_affine_backward", datetime.date(2021, 2, 20)),
     ("aten::fake_quantize_per_channel_affine_backward", datetime.date(2021, 2, 20)),
     ("aten::rowwise_prune", datetime.date(9999, 1, 1)),
-    ("aten::_foreach_mul_", datetime.date(2021, 3, 2)),
-    ("aten::_foreach_addcdiv_", datetime.date(2021, 3, 2)),
-    ("aten::_foreach_div", datetime.date(2021, 3, 2)),
-    ("aten::_foreach_addcmul_", datetime.date(2021, 3, 2)),
-    ("aten::_foreach_sub", datetime.date(2021, 3, 2)),
-    ("aten::_foreach_add", datetime.date(2021, 3, 2)),
-    ("aten::_foreach_sub_", datetime.date(2021, 3, 2)),
-    ("aten::_foreach_add_", datetime.date(2021, 3, 2)),
-    ("aten::_foreach_mul", datetime.date(2021, 3, 2)),
-    ("aten::_foreach_div_", datetime.date(2021, 3, 2)),
-    ("aten::_foreach_addcdiv", datetime.date(2021, 3, 2)),
-    ("aten::_foreach_addcmul", datetime.date(2021, 3, 2)),
 ]
 def allow_listed(schema, allow_list):
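
These allowlist entries were added by the original commit so its schema changes would pass the backward-compatibility check; with the schemas reverted, they are no longer needed. For intuition, an entry typically silences the check for matching ops until its expiry date, roughly like this hypothetical sketch (not the real helper):

    import datetime
    import re

    def allow_listed_sketch(schema_name, allow_list):
        # An entry (pattern, expiry) suppresses BC failures for matching ops
        # until the expiry date passes.
        today = datetime.date.today()
        return any(re.match(pattern, schema_name) and today < expiry
                   for pattern, expiry in allow_list)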


@@ -7,11 +7,25 @@ from torch._six import inf, nan
 N_values = [20] if not TEST_WITH_SLOW else [30, 300]
 class TestForeach(TestCase):
-    bin_ops = [
-        (torch._foreach_add, torch._foreach_add_, torch.add),
-        (torch._foreach_sub, torch._foreach_sub_, torch.sub),
-        (torch._foreach_mul, torch._foreach_mul_, torch.mul),
-        (torch._foreach_div, torch._foreach_div_, torch.div),
+    foreach_bin_ops = [
+        torch._foreach_add,
+        torch._foreach_sub,
+        torch._foreach_mul,
+        torch._foreach_div,
     ]
+    foreach_bin_ops_ = [
+        torch._foreach_add_,
+        torch._foreach_sub_,
+        torch._foreach_mul_,
+        torch._foreach_div_,
+    ]
+    torch_bin_ops = [
+        torch.add,
+        torch.sub,
+        torch.mul,
+        torch.div,
+    ]
     unary_ops = [
@@ -128,7 +142,7 @@ class TestForeach(TestCase):
             op(tensors, tensors1, tensors2, [2 for _ in range(N)])
     def _test_bin_op_list_alpha(self, device, dtype, foreach_op, foreach_op_, torch_op):
-        for N in N_values:
+        for N in [30, 300]:
             tensors1 = self._get_test_data(device, dtype, N)
             tensors2 = self._get_test_data(device, dtype, N)
             alpha = 2
@@ -338,7 +352,7 @@ class TestForeach(TestCase):
     @dtypes(*torch.testing.get_all_dtypes())
     def test_int_scalar(self, device, dtype):
        for N in N_values:
-            for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
+            for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
+                                                                     self.foreach_bin_ops_,
+                                                                     self.torch_bin_ops):
                 tensors = self._get_test_data(device, dtype, N)
                 scalar = 3
                 expected = [torch_bin_op(t, scalar) for t in tensors]
@@ -378,13 +394,15 @@ class TestForeach(TestCase):
     @dtypes(*torch.testing.get_all_dtypes())
     def test_int_scalarlist(self, device, dtype):
         for N in N_values:
-            for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
+            for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
+                                                                     self.foreach_bin_ops_,
+                                                                     self.torch_bin_ops):
                 tensors = self._get_test_data(device, dtype, N)
                 scalars = [1 for _ in range(N)]
                 expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
                 # we dont support bool and complex types on CUDA for now
-                if dtype in torch.testing.get_all_complex_dtypes() and self.device_type == 'cuda':
+                if (dtype in torch.testing.get_all_complex_dtypes() or dtype == torch.bool) and self.device_type == 'cuda':
                     with self.assertRaisesRegex(RuntimeError, "not implemented for"):
                         foreach_bin_op_(tensors, scalars)
@@ -395,41 +413,36 @@ class TestForeach(TestCase):
                 res = foreach_bin_op(tensors, scalars)
                 if dtype == torch.bool:
-                    self.assertEqual(res, [torch_bin_op(t, s) for t, s in zip(tensors, scalars)])
+                    self.assertEqual(res, [torch_bin_op(t.to(torch.float32), s) for t, s in zip(tensors, scalars)])
-                    with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
+                    with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
                         foreach_bin_op_(tensors, scalars)
                     return
                 # test out of place
                 if dtype in torch.testing.integral_types():
                     if self.device_type == 'cpu':
-                        self.assertEqual(res, expected)
+                        self.assertEqual(res, [e.to(torch.float32) for e in expected])
                     else:
                         # TODO[type promotion]: Fix once type promotion is enabled.
                         self.assertEqual(res, [e.to(dtype) for e in expected])
                 else:
                     self.assertEqual(res, expected)
                 # test in-place
-                if dtype in torch.testing.floating_types() and self.device_type == 'cpu':
-                    foreach_bin_op_(tensors, scalars)
+                if dtype in torch.testing.integral_types() and self.device_type == 'cpu':
+                    with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
+                        foreach_bin_op_(tensors, scalars)
+                    return
-                else:
-                    if foreach_bin_op_ == torch._foreach_div_ and \
-                            dtype in torch.testing.integral_types() and \
-                            self.device_type == 'cpu':
-                        with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
-                            foreach_bin_op_(tensors, scalars)
-                    else:
-                        foreach_bin_op_(tensors, scalars)
-                        self.assertEqual(res, tensors)
+                foreach_bin_op_(tensors, scalars)
+                self.assertEqual(res, tensors)
     @skipCUDAIfRocm
     @dtypes(*torch.testing.get_all_dtypes())
     def test_float_scalar(self, device, dtype):
         for N in N_values:
-            for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
+            for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
+                                                                     self.foreach_bin_ops_,
+                                                                     self.torch_bin_ops):
                 tensors = self._get_test_data(device, dtype, N)
                 scalar = 3.3
@@ -471,38 +484,17 @@ class TestForeach(TestCase):
     @dtypes(*torch.testing.get_all_dtypes())
     def test_float_scalarlist(self, device, dtype):
         for N in N_values:
-            for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
+            for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
+                                                                     self.foreach_bin_ops_,
+                                                                     self.torch_bin_ops):
                 tensors = self._get_test_data(device, dtype, N)
                 scalars = [1.1 for _ in range(N)]
-                # Bool case
-                if dtype == torch.bool:
-                    if foreach_bin_op == torch._foreach_sub:
-                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
-                            expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
-                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
-                            res = foreach_bin_op(tensors, scalars)
-                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
-                            [t.sub_(scalar) for t, scalar in zip(tensors, scalars)]
-                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
-                            foreach_bin_op_(tensors, scalars)
-                        continue
-                    res = foreach_bin_op(tensors, scalars)
-                    expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
-                    self.assertEqual(res, expected)
-                    with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
-                        foreach_bin_op_(tensors, scalars)
-                    continue
                 # If incoming dtype is float16 or bfloat16, runs in float32 and casts output back to dtype.
                 control_dtype = torch.float32 if (self.device_type == 'cuda' and
                                                   (dtype is torch.float16 or dtype is torch.bfloat16)) else dtype
-                expected = [torch_bin_op(t.to(dtype=control_dtype), s) for t, s in zip(tensors, scalars)]
+                expected = [torch_bin_op(t.to(dtype=control_dtype),
+                                         s) for t, s in zip(tensors, scalars)]
                 if (dtype is torch.float16 or dtype is torch.bfloat16):
                     expected = [e.to(dtype=dtype) for e in expected]
@@ -517,11 +509,21 @@ class TestForeach(TestCase):
                 res = foreach_bin_op(tensors, scalars)
-                if dtype in torch.testing.integral_types() and self.device_type == 'cuda':
-                    self.assertEqual(res, expected)
-                    continue
+                if dtype == torch.bool:
+                    # see TODO[Fix scalar list]
+                    self.assertEqual(res, [torch_bin_op(t.to(torch.float32), s) for t, s in zip(tensors, scalars)])
+                    with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
+                        foreach_bin_op_(tensors, scalars)
+                    return
+                if dtype in torch.testing.integral_types() and self.device_type == 'cuda':
+                    # see TODO[Fix scalar list]
+                    self.assertEqual(res, [e.to(dtype) for e in expected])
+                    foreach_bin_op_(tensors, scalars)
+                    self.assertEqual(tensors, res)
+                    return
                 else:
                     if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM:
                         self.assertEqual(res, expected, atol=1.e-3, rtol=self.dtype_precisions[dtype][0])
@@ -531,7 +533,7 @@ class TestForeach(TestCase):
                 if dtype in torch.testing.integral_types() and self.device_type == "cpu":
                     with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
                         foreach_bin_op_(tensors, scalars)
-                    continue
+                    return
                 foreach_bin_op_(tensors, scalars)
                 if (dtype is torch.float16 or dtype is torch.bfloat16) and TEST_WITH_ROCM:
@@ -543,29 +545,34 @@ class TestForeach(TestCase):
     @dtypes(*torch.testing.get_all_dtypes())
     def test_complex_scalar(self, device, dtype):
         for N in N_values:
-            for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
+            for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
+                                                                     self.foreach_bin_ops_,
+                                                                     self.torch_bin_ops):
                 tensors = self._get_test_data(device, dtype, N)
                 scalar = 3 + 5j
-                expected = [torch_bin_op(t, scalar) for t in tensors]
                 # Bool case
                 if dtype == torch.bool:
                     if foreach_bin_op == torch._foreach_sub:
-                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator,"):
+                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
                             foreach_bin_op_(tensors, scalar)
-                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator,"):
+                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
                             foreach_bin_op(tensors, scalar)
-                    continue
+                    return
+                if dtype in torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=True) and \
+                        self.device_type == 'cuda':
+                    with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"):
+                        foreach_bin_op_(tensors, scalar)
+                    with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"):
+                        foreach_bin_op(tensors, scalar)
+                    return
                 res = foreach_bin_op(tensors, scalar)
+                expected = [torch_bin_op(t, scalar) for t in tensors]
                 self.assertEqual(res, expected)
-                if dtype in torch.testing.get_all_fp_dtypes() and self.device_type == 'cuda':
-                    with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
-                        foreach_bin_op_(tensors, scalar)
-                    continue
                 if dtype not in [torch.complex64, torch.complex128]:
                     with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
                         foreach_bin_op_(tensors, scalar)
@@ -573,48 +580,38 @@ class TestForeach(TestCase):
                     foreach_bin_op_(tensors, scalar)
                     self.assertEqual(res, tensors)
     @skipCUDAIfRocm
     @dtypes(*torch.testing.get_all_dtypes())
     def test_complex_scalarlist(self, device, dtype):
         for N in N_values:
-            for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
+            for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
+                                                                     self.foreach_bin_ops_,
+                                                                     self.torch_bin_ops):
                 tensors = self._get_test_data(device, dtype, N)
                 scalars = [3 + 5j for _ in range(N)]
-                expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
                 # Bool case
                 if dtype == torch.bool:
                     if foreach_bin_op == torch._foreach_sub:
-                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator"):
-                            foreach_bin_op_(tensors, scalars)
+                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
+                            foreach_bin_op_(tensors, scalar)
-                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator"):
-                            foreach_bin_op(tensors, scalars)
-                    continue
+                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
+                            foreach_bin_op(tensors, scalar)
+                    return
+                expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
                 if dtype in [torch.complex64, torch.complex128] and self.device_type == "cuda":
                     with self.assertRaisesRegex(RuntimeError, "not implemented for"):
                         res = foreach_bin_op(tensors, scalars)
                     with self.assertRaisesRegex(RuntimeError, "not implemented for"):
                         foreach_bin_op_(tensors, scalars)
-                    continue
+                else:
+                    with self.assertRaisesRegex(TypeError, "argument 'scalars' must be tuple of floats"):
+                        res = foreach_bin_op(tensors, scalars)
-                self.assertEqual(res, expected)
-                if dtype not in [torch.complex64, torch.complex128]:
-                    with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
-                        foreach_bin_op_(tensors, scalars)
-                else:
+                    with self.assertRaisesRegex(TypeError, "argument 'scalars' must be tuple of floats"):
                         foreach_bin_op_(tensors, scalars)
-                        self.assertEqual(res, tensors)
     @skipCUDAIfRocm
     @dtypes(*torch.testing.get_all_dtypes())
     def test_bool_scalar(self, device, dtype):
         for N in N_values:
-            for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
+            for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
+                                                                     self.foreach_bin_ops_,
+                                                                     self.torch_bin_ops):
                 tensors = self._get_test_data(device, dtype, N)
                 scalar = True
@@ -663,7 +660,9 @@ class TestForeach(TestCase):
     @dtypes(*torch.testing.get_all_dtypes())
     def test_bool_scalarlist(self, device, dtype):
         for N in N_values:
-            for foreach_bin_op, foreach_bin_op_, torch_bin_op in self.bin_ops:
+            for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
+                                                                     self.foreach_bin_ops_,
+                                                                     self.torch_bin_ops):
                 tensors = self._get_test_data(device, dtype, N)
                 scalars = [True for _ in range(N)]
@@ -677,24 +676,20 @@ class TestForeach(TestCase):
                        return
                    else:
                        if foreach_bin_op == torch._foreach_sub:
-                            with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool tensors"):
+                            with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
                                foreach_bin_op_(tensors, scalars)
-                            with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool tensors"):
+                            with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
                                foreach_bin_op(tensors, scalars)
                        else:
-                            expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
-                            res = foreach_bin_op(tensors, scalars)
-                            self.assertEqual(res, expected)
-                            if foreach_bin_op_ == torch._foreach_div_:
-                                with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired"):
-                                    foreach_bin_op_(tensors, scalars)
-                            else:
+                            with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired"):
                                foreach_bin_op_(tensors, scalars)
-                                self.assertEqual(res, tensors)
+                            res = foreach_bin_op(tensors, scalars)
+                            for r in res:
+                                self.assertTrue(r.dtype == torch.float32)
                else:
-                    # we dont support complex types on CUDA for now
+                    # we dont support bool and complex types on CUDA for now
                    if (dtype in torch.testing.get_all_complex_dtypes()) and self.device_type == 'cuda':
                        with self.assertRaisesRegex(RuntimeError, "not implemented for"):
                            foreach_bin_op_(tensors, scalars)
@@ -703,44 +698,57 @@ class TestForeach(TestCase):
                            foreach_bin_op(tensors, scalars)
                        return
-                    if foreach_bin_op == torch._foreach_sub and self.device_type == "cpu":
-                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
-                            foreach_bin_op_(tensors, scalars)
+                    if foreach_bin_op == torch._foreach_sub:
+                        if self.device_type == "cpu":
+                            # see TODO[Fix scalar list]
+                            res = foreach_bin_op(tensors, scalars)
+                            if dtype in torch.testing.integral_types():
+                                self.assertEqual(res, [r.to(torch.float32) for r in [torch_bin_op(t, 1) for t in tensors]])
-                        with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
-                            foreach_bin_op(tensors, scalars)
+                                with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the "):
+                                    foreach_bin_op_(tensors, scalars)
+                            else:
+                                self.assertEqual(res, [torch_bin_op(t, 1) for t in tensors])
+                                foreach_bin_op_(tensors, scalars)
+                                self.assertEqual(res, tensors)
+                        else:
+                            # see TODO[Fix scalar list]
+                            res = foreach_bin_op(tensors, scalars)
+                            if dtype in torch.testing.integral_types():
+                                self.assertEqual(res, [r.to(dtype) for r in [torch_bin_op(t, 1) for t in tensors]])
+                            else:
+                                self.assertEqual(res, [torch_bin_op(t, 1) for t in tensors])
+                            foreach_bin_op_(tensors, scalars)
+                            self.assertEqual(res, tensors)
                    else:
                        if self.device_type == "cpu":
                            expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
                            res = foreach_bin_op(tensors, scalars)
-                            self.assertEqual(res, expected)
+                            # see TODO[Fix scalar list]
+                            if dtype in torch.testing.integral_types():
+                                self.assertEqual(res, [e.to(torch.float32) for e in expected])
+                            else:
+                                self.assertEqual(res, expected)
-                            if dtype in torch.testing.integral_types() and foreach_bin_op_ == torch._foreach_div_:
+                            if dtype in torch.testing.integral_types():
                                with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired "):
                                    foreach_bin_op_(tensors, scalars)
                            else:
                                foreach_bin_op_(tensors, scalars)
                                self.assertEqual(tensors, expected)
                        else:
-                            if foreach_bin_op == torch._foreach_sub:
-                                with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bo"):
-                                    expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
+                            expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
-                                res = foreach_bin_op(tensors, scalars)
+                            res = foreach_bin_op(tensors, scalars)
-                                foreach_bin_op_(tensors, scalars)
-                                self.assertEqual(res, tensors)
+                            if dtype in torch.testing.integral_types():
+                                self.assertEqual(res, [e.to(dtype) for e in expected])
-                            else:
-                                expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
-                                res = foreach_bin_op(tensors, scalars)
-                                self.assertEqual(res, expected)
-                                if dtype in torch.testing.integral_types():
-                                    self.assertEqual(res, [e.to(dtype) for e in expected])
+                            else:
+                                self.assertEqual(res, expected)
+                            foreach_bin_op_(tensors, scalars)
+                            self.assertEqual(res, tensors)
-                                foreach_bin_op_(tensors, scalars)
-                                self.assertEqual(res, tensors)
     @dtypes(*torch.testing.get_all_dtypes())
     def test_add_with_different_size_tensors(self, device, dtype):
@@ -783,7 +791,7 @@ class TestForeach(TestCase):
     #
     # Ops with list
     #
     def test_bin_op_list_error_cases(self, device):
-        for bin_op, bin_op_, _ in self.bin_ops:
+        for bin_op, bin_op_ in zip(self.foreach_bin_ops, self.foreach_bin_ops_):
            tensors1 = []
            tensors2 = []
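
Taken together, the reworked tests pin down the interim promotion behavior that follows from the float[] schema: even Python integers arrive as doubles, so integer tensors compute in floating point on the CPU slow path. A condensed sketch of what test_int_scalarlist asserts for CPU (illustrative):

    import torch

    tensors = [torch.ones(3, dtype=torch.int32) for _ in range(4)]
    scalars = [1, 1, 1, 1]                       # passed through the float[] schema

    res = torch._foreach_add(tensors, scalars)   # CPU: float32 results, since int + double promotes
    # The in-place variant cannot store a float result in int32 tensors:
    # torch._foreach_add_(tensors, scalars) raises
    # RuntimeError: result type Float can't be cast to the desired output type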