From 6a368b3fc59a61071472d4472fc8ca298e184934 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Tue, 12 Nov 2024 19:03:38 +0000
Subject: [PATCH] Add ScalarList overload to `_foreach_lerp` (#134482)

Related:
- https://github.com/pytorch/pytorch/issues/133367

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134482
Approved by: https://github.com/janeyx99
---
 aten/src/ATen/native/ForeachOpsKernels.cpp    | 59 ++++++++++-----
 aten/src/ATen/native/ForeachUtils.h           | 13 ++++
 aten/src/ATen/native/cuda/ForeachFunctors.cuh | 57 +++++++++++++++
 aten/src/ATen/native/cuda/ForeachTernaryOp.cu | 71 +++++++++++++++++++
 aten/src/ATen/native/native_functions.yaml    | 16 +++++
 ...asDecompTest.test_has_decomposition.expect |  3 +
 test/test_foreach.py                          |  2 +
 torch/_inductor/decomposition.py              | 14 ++++
 torch/optim/_adafactor.py                     | 12 +---
 .../_internal/common_methods_invocations.py   | 28 +++++++-
 10 files changed, 245 insertions(+), 30 deletions(-)

diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp
index bf88cc50921..64c39fcaef2 100644
--- a/aten/src/ATen/native/ForeachOpsKernels.cpp
+++ b/aten/src/ATen/native/ForeachOpsKernels.cpp
@@ -411,26 +411,49 @@ FOREACH_POINTWISE_OP_SCALARLIST(addcmul)
 FOREACH_POINTWISE_OP_TENSOR(addcdiv)
 FOREACH_POINTWISE_OP_TENSOR(addcmul)
 
-#define FOREACH_TERNARY_OP(OP)                                          \
-  std::vector<Tensor> foreach_tensor_ternary_##OP##_slow(               \
-      TensorList tensors1, TensorList tensors2, TensorList tensors3) {  \
-    check_foreach_api_restrictions(tensors1, tensors2, tensors3);       \
-    std::vector<Tensor> result;                                         \
-    for (const auto i : c10::irange(tensors1.size())) {                 \
-      result.emplace_back(tensors1[i].OP(tensors2[i], tensors3[i]));    \
-    }                                                                   \
-    return result;                                                      \
-  }                                                                     \
-                                                                        \
-  void foreach_tensor_ternary_##OP##_slow_(                             \
-      TensorList tensors1, TensorList tensors2, TensorList tensors3) {  \
-    check_foreach_api_restrictions(tensors1, tensors2, tensors3);       \
-    for (const auto i : c10::irange(tensors1.size())) {                 \
-      tensors1[i].OP##_(tensors2[i], tensors3[i]);                      \
-    }                                                                   \
+std::vector<Tensor> foreach_tensor_ternary_lerp_slow(
+    TensorList tensors1,
+    TensorList tensors2,
+    TensorList tensors3) {
+  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
+  std::vector<Tensor> result;
+  for (const auto i : c10::irange(tensors1.size())) {
+    result.emplace_back(tensors1[i].lerp(tensors2[i], tensors3[i]));
   }
+  return result;
+}
 
-FOREACH_TERNARY_OP(lerp)
+void foreach_tensor_ternary_lerp_slow_(
+    TensorList tensors1,
+    TensorList tensors2,
+    TensorList tensors3) {
+  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
+  for (const auto i : c10::irange(tensors1.size())) {
+    tensors1[i].lerp_(tensors2[i], tensors3[i]);
+  }
+}
+
+std::vector<Tensor> foreach_tensor_lerp_scalarlist_kernel_slow(
+    TensorList tensors1,
+    TensorList tensors2,
+    at::ArrayRef<Scalar> scalars) {
+  check_foreach_api_restrictions(tensors1, tensors2, scalars);
+  std::vector<Tensor> result;
+  for (const auto i : c10::irange(tensors1.size())) {
+    result.emplace_back(tensors1[i].lerp(tensors2[i], scalars[i]));
+  }
+  return result;
+}
+
+void foreach_tensor_lerp_scalarlist_kernel_slow_(
+    TensorList tensors1,
+    TensorList tensors2,
+    at::ArrayRef<Scalar> scalars) {
+  check_foreach_api_restrictions(tensors1, tensors2, scalars);
+  for (const auto i : c10::irange(tensors1.size())) {
+    tensors1[i].lerp_(tensors2[i], scalars[i]);
+  }
+}
 
 void foreach_tensor_zero_slow_(TensorList tensors) {
   check_foreach_api_restrictions(tensors);
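[Editor's note — illustration, not part of the patch] The CPU "slow" entry points above are plain per-tensor loops: the old FOREACH_TERNARY_OP macro is unrolled into explicit lerp functions with unchanged behavior, and the new ScalarList variants apply one Scalar weight per tensor pair. A minimal Python sketch of the semantics they implement (function names are illustrative, not from the patch):

    import torch

    def ternary_lerp_slow(starts, ends, weights):
        # _foreach_lerp.List slow path: per-tensor lerp with a weight *tensor*.
        return [s.lerp(e, w) for s, e, w in zip(starts, ends, weights)]

    def lerp_scalarlist_slow(starts, ends, scalars):
        # _foreach_lerp.ScalarList slow path: per-tensor lerp, one Scalar weight each.
        assert len(starts) == len(ends) == len(scalars)
        return [s.lerp(e, w) for s, e, w in zip(starts, ends, scalars)]

[End of editor's note]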
diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h
index a8fbe13b8da..56b7a6f98e7 100644
--- a/aten/src/ATen/native/ForeachUtils.h
+++ b/aten/src/ATen/native/ForeachUtils.h
@@ -98,6 +98,19 @@ inline void check_foreach_api_restrictions(
       scalars.size());
 }
 
+inline void check_foreach_api_restrictions(
+    TensorList tensors1,
+    TensorList tensors2,
+    ArrayRef<Scalar> scalars) {
+  check_foreach_api_restrictions(tensors1, tensors2);
+  TORCH_CHECK(
+      tensors1.size() == scalars.size(),
+      "Tensor list must have same number of elements as scalar list, got ",
+      tensors1.size(),
+      " and ",
+      scalars.size());
+}
+
 // Helper function called in check_fast_path_restrictions to check whether all
 // corresponding tensors (aligning in index across the tensorLists) share the
 // same device and dtype.
diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
index 55e4fd7a598..645b095c5a6 100644
--- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh
+++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh
@@ -663,6 +663,63 @@ struct TernaryOpScalarFunctor {
   }
 };
 
+template <typename T, int depth, int r_args_depth, int res_arg_index>
+struct TernaryOpScalarListFunctor {
+  using opmath_t = at::opmath_type<T>;
+  template <typename Op>
+  __device__ __forceinline__ void operator()(
+      int chunk_size,
+      TensorListScalarListMetadata<opmath_t, depth>& tl,
+      Op op) {
+    static_assert(depth == 2 || depth == 3, "");
+    static_assert(depth >= r_args_depth, "");
+    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
+
+    T* args[depth];
+    const bool all_aligned =
+        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+    n -= chunk_idx * chunk_size;
+    T r_args[r_args_depth][kILP];
+    const opmath_t scalar = tl.scalar_vals[tensor_loc];
+
+    // to make things simple, we put aligned case in a different code path
+    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
+      for (int64_t i_start = threadIdx.x;
+           i_start * kILP < n && i_start * kILP < chunk_size;
+           i_start += blockDim.x) {
+        // load
+        load_store(r_args[0], args[0], 0, i_start);
+        load_store(r_args[1], args[1], 0, i_start);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] =
+              op(static_cast<opmath_t>(r_args[0][ii]),
+                 static_cast<opmath_t>(r_args[1][ii]),
+                 scalar);
+        }
+        // store
+        load_store(args[res_arg_index], r_args[0], i_start, 0);
+      }
+    } else {
+      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
+           i_start += blockDim.x * kILP) {
+        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
+#pragma unroll
+        for (int ii = 0; ii < kILP; ii++) {
+          r_args[0][ii] =
+              op(static_cast<opmath_t>(r_args[0][ii]),
+                 static_cast<opmath_t>(r_args[1][ii]),
+                 scalar);
+        }
+        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
+      }
+    }
+  }
+};
+
 template <typename T>
 struct power_functor {
   C10_DEVICE T operator()(const T& a, const T& b) const {
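[Editor's note — illustration, not part of the patch] TernaryOpScalarListFunctor mirrors the existing TernaryOpScalarFunctor, except that the weight is fetched per tensor from tl.scalar_vals[tensor_loc] instead of being one value shared by the whole list. From the Python side the difference looks roughly like this (a sketch; the ScalarList call assumes a build that includes this patch):

    import torch

    starts = [torch.zeros(2), torch.zeros(2)]
    ends = [torch.ones(2), torch.ones(2)]

    # Existing Scalar overload: a single weight for every tensor in the list.
    shared = torch._foreach_lerp(starts, ends, 0.5)                # [0.5, 0.5], [0.5, 0.5]

    # New ScalarList overload: one weight per tensor.
    per_tensor = torch._foreach_lerp(starts, ends, [0.25, 0.75])   # [0.25, 0.25], [0.75, 0.75]

[End of editor's note]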
diff --git a/aten/src/ATen/native/cuda/ForeachTernaryOp.cu b/aten/src/ATen/native/cuda/ForeachTernaryOp.cu
index e13f2015f1d..a6599287f3d 100644
--- a/aten/src/ATen/native/cuda/ForeachTernaryOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachTernaryOp.cu
@@ -156,4 +156,75 @@ void foreach_tensor_lerp_list_cuda_(
         weight.to<opmath_t>());
   });
 }
+
+std::vector<Tensor> foreach_tensor_lerp_scalarlist_cuda(
+    TensorList tensors1,
+    TensorList tensors2,
+    at::ArrayRef<Scalar> scalars) {
+  check_foreach_api_restrictions(tensors1, tensors2, scalars);
+  if (!can_use_fast_route({tensors1, tensors2}, scalars, true)) {
+    return foreach_tensor_lerp_scalarlist_kernel_slow(
+        tensors1, tensors2, scalars);
+  }
+  std::vector<at::Tensor> vec_res;
+  vec_res.reserve(tensors1.size());
+  for (const auto& t : tensors1) {
+    vec_res.emplace_back(at::native::empty_like(t));
+  }
+
+  std::vector<std::vector<at::Tensor>> tensor_lists{
+      tensors1.vec(), tensors2.vec(), vec_res};
+
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      tensors1[0].scalar_type(),
+      "foreach_tensor_lerp_scalarlist_cuda",
+      [&]() {
+        using opmath_t = typename at::opmath_type<scalar_t>;
+        multi_tensor_apply<3, opmath_t>(
+            tensor_lists,
+            scalars,
+            TernaryOpScalarListFunctor<
+                scalar_t,
+                /* depth */ 3,
+                /* r_args_depth */ 2,
+                /* res_arg_index */ 2>(),
+            LerpFunctor<opmath_t>());
+      });
+
+  return tensor_lists[2];
+}
+
+void foreach_tensor_lerp_scalarlist_cuda_(
+    TensorList tensors1,
+    TensorList tensors2,
+    at::ArrayRef<Scalar> scalars) {
+  check_foreach_api_restrictions(tensors1, tensors2, scalars);
+  if (!can_use_fast_route({tensors1, tensors2}, scalars, true)) {
+    return foreach_tensor_lerp_scalarlist_kernel_slow_(
+        tensors1, tensors2, scalars);
+  }
+
+  std::vector<std::vector<at::Tensor>> tensor_lists{
+      tensors1.vec(), tensors2.vec()};
+
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      tensors1[0].scalar_type(),
+      "foreach_tensor_lerp_scalarlist_cuda_",
+      [&]() {
+        using opmath_t = typename at::opmath_type<scalar_t>;
+        multi_tensor_apply<2, opmath_t>(
+            tensor_lists,
+            scalars,
+            TernaryOpScalarListFunctor<
+                scalar_t,
+                /* depth */ 2,
+                /* r_args_depth */ 2,
+                /* res_arg_index */ 0>(),
+            LerpFunctor<opmath_t>());
+      });
+}
 } // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 97e36a2e289..cc9249dac43 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -11105,6 +11105,22 @@
     CUDA: foreach_tensor_lerp_list_cuda_
   autogen: _foreach_lerp.Scalar_out
 
+- func: _foreach_lerp.ScalarList(Tensor[] self, Tensor[] tensors1, Scalar[] weight) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_lerp_scalarlist_kernel_slow
+    CUDA: foreach_tensor_lerp_scalarlist_cuda
+  autogen: _foreach_lerp.ScalarList_out
+
+- func: _foreach_lerp_.ScalarList(Tensor(a!)[] self, Tensor[] tensors1, Scalar[] weight) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_lerp_scalarlist_kernel_slow_
+    CUDA: foreach_tensor_lerp_scalarlist_cuda_
+  autogen: _foreach_lerp.ScalarList_out
+
 - func: _foreach_lgamma(Tensor[] self) -> Tensor[]
   device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
   variants: function
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 9cc72c550e1..98bed17ebd3 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -232,9 +232,12 @@ aten::_foreach_frac_
 aten::_foreach_lerp.List
 aten::_foreach_lerp.List_out
 aten::_foreach_lerp.Scalar
+aten::_foreach_lerp.ScalarList
+aten::_foreach_lerp.ScalarList_out
 aten::_foreach_lerp.Scalar_out
 aten::_foreach_lerp_.List
 aten::_foreach_lerp_.Scalar
+aten::_foreach_lerp_.ScalarList
 aten::_foreach_lgamma
 aten::_foreach_lgamma.out
 aten::_foreach_lgamma_
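[Editor's note — illustration, not part of the patch] The native_functions.yaml entries above register the functional and in-place ScalarList overloads (plus the autogenerated ScalarList_out variant), and the HasDecompTest expect file is updated to match. On a build that includes this patch, the registrations could be spot-checked from Python along these lines:

    import torch

    # Each of these should resolve to an OpOverload once the yaml entries are built in.
    print(torch.ops.aten._foreach_lerp.ScalarList)
    print(torch.ops.aten._foreach_lerp_.ScalarList)
    print(torch.ops.aten._foreach_lerp.ScalarList_out)

[End of editor's note]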
diff --git a/test/test_foreach.py b/test/test_foreach.py
index 4ee8e24b538..9cfc98c1c80 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -1534,6 +1534,8 @@ def check_autodiff_sample(op, sample, dtype, is_inplace):
         or (isinstance(sample.args[-1], complex))
     )
     if rhs_arg_has_complex_number and dtype == torch.float64:
+        if op.name == "_foreach_lerp":
+            return False, "value cannot be converted to type double without overflow"
         if op.name in (
             "_foreach_clamp_max",
             "_foreach_clamp_min",
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 3f90d40810b..ec8821bacae 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -749,6 +749,20 @@ def _foreach_lerp_scalar(
     )
 
 
+@register_decomposition(aten._foreach_lerp.ScalarList)
+def _foreach_lerp_scalarlist(
+    start_tensors: List[torch.Tensor],
+    end_tensors: List[torch.Tensor],
+    scalars: List[torch.types.Number],
+) -> List[torch.Tensor]:
+    return aten._foreach_add.List(
+        start_tensors,
+        aten._foreach_mul.ScalarList(
+            aten._foreach_sub.List(end_tensors, start_tensors), scalars
+        ),
+    )
+
+
 @aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
 @register_decomposition(aten.miopen_batch_norm)
 def miopen_batch_norm(
diff --git a/torch/optim/_adafactor.py b/torch/optim/_adafactor.py
index 4c5dbdd926e..bc58180ed03 100644
--- a/torch/optim/_adafactor.py
+++ b/torch/optim/_adafactor.py
@@ -541,9 +541,7 @@ def _multi_tensor_adafactor(
             ]
             torch._foreach_mul_(row_means, row_means)
             torch._foreach_div_(row_means, [grad.size(-1) for grad in device_grads])
-            torch._foreach_mul_(device_row_vars, beta2_ts)
-            torch._foreach_mul_(row_means, one_minus_beta2_ts)
-            torch._foreach_add_(device_row_vars, row_means)
+            torch._foreach_lerp_(device_row_vars, row_means, one_minus_beta2_ts)
             del row_means
 
             # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g
@@ -552,9 +550,7 @@ def _multi_tensor_adafactor(
             ]
             torch._foreach_mul_(col_means, col_means)
             torch._foreach_div_(col_means, [grad.size(-2) for grad in device_grads])
-            torch._foreach_mul_(device_col_vars, beta2_ts)
-            torch._foreach_mul_(col_means, one_minus_beta2_ts)
-            torch._foreach_add_(device_col_vars, col_means)
+            torch._foreach_lerp_(device_col_vars, col_means, one_minus_beta2_ts)
             del col_means
 
             var_estimates = [
@@ -574,9 +570,7 @@ def _multi_tensor_adafactor(
                 ), "variance should be defined when grad is a vector"
 
                 grads_squared = torch._foreach_mul(device_grads, device_grads)
-                torch._foreach_mul_(device_variances, beta2_ts)
-                torch._foreach_mul_(grads_squared, one_minus_beta2_ts)
-                torch._foreach_add_(device_variances, grads_squared)
+                torch._foreach_lerp_(device_variances, grads_squared, one_minus_beta2_ts)
                 del grads_squared
 
                 # avoid writing into variance during update
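[Editor's note — illustration, not part of the patch] The Adafactor change relies on the identity lerp(x, y, 1 - beta2) = x + (1 - beta2) * (y - x) = beta2 * x + (1 - beta2) * y, so each removed mul_/mul_/add_ triple collapses into a single fused _foreach_lerp_ with the same result up to rounding. A minimal numerical check of that identity (a sketch, illustrative only):

    import torch

    beta2 = 0.999
    x = torch.randn(4, 4, dtype=torch.float64)
    y = torch.randn(4, 4, dtype=torch.float64)

    old = beta2 * x + (1 - beta2) * y   # the removed three-op sequence, spelled out
    new = x.lerp(y, 1 - beta2)          # the single lerp used after this patch
    torch.testing.assert_close(old, new)

[End of editor's note]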
"test_meta_inplace", dtypes=integral_types_and(torch.bool)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=integral_types_and(torch.bool)), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_inplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_meta_outplace", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=integral_types_and(torch.bool), + ), + DecorateInfo( + unittest.expectedFailure, + "TestMeta", + "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=integral_types_and(torch.bool), + ), ), ), ]