[ROCm] Extend vectorized elementwise kernel to more heterogeneous tensor types. (#149738)

This patch extends the initial support for "vectorized templated" kernels to the following input tensor types:
(BFloat16, float)
(float, float16)
(float16, float)
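
As an illustration (not part of the patch; tensor names and sizes are hypothetical), the kind of op that can now be routed to the specialized path is a binary elementwise op whose two inputs carry one of the dtype pairs above:

#include <ATen/ATen.h>

// Sketch: mixed-dtype binary elementwise op on a ROCm device (ATen exposes
// HIP devices under at::kCUDA). With this patch, a (BFloat16, float) input
// pair can be routed to the vectorized templated kernel instead of the
// generic casting elementwise path, provided the access pattern vectorizes.
void example() {
  auto a = at::randn({1 << 20}, at::device(at::kCUDA).dtype(at::kBFloat16));
  auto b = at::randn({1 << 20}, at::device(at::kCUDA).dtype(at::kFloat));
  auto c = a + b; // binary functor with heterogeneous input dtypes
}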

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149738
Approved by: https://github.com/jeffdaily
Authored by Carlo Bertolli on 2025-03-25 01:09:58 +00:00; committed by PyTorch MergeBot
Parent: 2a9e737839
Commit: 5af9cb12b7
2 changed files with 156 additions and 47 deletions

File 1 of 2:

@@ -292,8 +292,11 @@ __global__ void vectorized_templated_elementwise_kernel(
     out_calc_t out_calc,
     loader_t loader,
     storer_t storer) {
-  int remaining =
-      N - vectorized_templated_config::block_work_size() * blockIdx.x;
+  int remaining = N -
+      vectorized_templated_config::block_work_size() *
+          (gridDim.x - blockIdx.x - 1);
+  constexpr bool reverted_idx = true;
   if (remaining <
       vectorized_templated_config::block_work_size()) { // if this block handles
                                                         // the remainder,
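
The rewritten formula pairs with the reverted_idx flag and the elementwise_kernel_helper change in the second file: block b now processes chunk gridDim.x - b - 1, so the remainder chunk lands on blockIdx.x == 0 rather than on the last block. A standalone host-side sketch of the mapping (values are illustrative, not from the patch):

#include <cstdio>

// With reversed indexing, block b owns chunk (grid - b - 1); only block 0
// sees remaining < block_work_size and takes the unrolled remainder path.
int main() {
  const int N = 1000, block_work_size = 256;
  const int grid = (N + block_work_size - 1) / block_work_size; // 4 blocks
  for (int b = 0; b < grid; ++b) {
    int chunk = grid - b - 1;                    // reversed index
    int remaining = N - block_work_size * chunk; // the kernel's formula
    std::printf("block %d -> chunk %d, remaining %d\n", b, chunk, remaining);
  }
  // Prints: block 0 -> chunk 3, remaining 232  (remainder loop)
  //         block 3 -> chunk 0, remaining 1000 (full, vectorized path)
  return 0;
}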
@@ -307,18 +310,17 @@ __global__ void vectorized_templated_elementwise_kernel(
         storer_t,
         vectorized_templated_config::elems_per_thread()>(
         data, remaining, inp_calc, out_calc, loader, storer);
-    elementwise_kernel_helper(f, policy);
+    elementwise_kernel_helper<reverted_idx>(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use
            // vectorized memory access
-    elementwise_kernel_helper(
-        f,
-        memory::policies::vectorized_templated<
-            vec_size,
-            array_t,
-            vectorized_templated_config::elems_per_thread(),
-            vectorized_templated_config::num_threads(),
-            OutputType,
-            InputTypes...>(data));
+    auto policy = memory::policies::vectorized_templated<
+        vec_size,
+        array_t,
+        vectorized_templated_config::elems_per_thread(),
+        vectorized_templated_config::num_threads(),
+        OutputType,
+        InputTypes...>(data);
+    elementwise_kernel_helper<reverted_idx>(f, policy);
   }
 }
@@ -544,19 +546,34 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
 #ifdef USE_ROCM
 namespace {
-template <typename TupleLike, size_t arity, size_t arg_num = 0>
-struct check_types {
+template <
+    typename TupleLike,
+    typename FirstParamTy,
+    typename SecondParamTy,
+    size_t arity,
+    size_t arg_num = 0>
+struct check_binary_functor_types_for_specialization {
   constexpr static inline bool check() {
     if constexpr (arity != 2)
       return false;
     if constexpr (arg_num == 0) {
       using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
-      if constexpr (std::is_same_v<float, SelectedType>)
-        return check_types<TupleLike, arity, arg_num + 1>::check();
+      if constexpr (std::is_same_v<FirstParamTy, SelectedType>)
+        return check_binary_functor_types_for_specialization<
+            TupleLike,
+            FirstParamTy,
+            SecondParamTy,
+            arity,
+            arg_num + 1>::check();
     } else if constexpr (arg_num == 1) {
       using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
-      if constexpr (std::is_same_v<float, SelectedType2>)
-        return check_types<TupleLike, arity, arg_num + 1>::check();
+      if constexpr (std::is_same_v<SecondParamTy, SelectedType2>)
+        return check_binary_functor_types_for_specialization<
+            TupleLike,
+            FirstParamTy,
+            SecondParamTy,
+            arity,
+            arg_num + 1>::check();
     }
     return false;
   }
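
A simplified, self-contained equivalent of what this recursive trait verifies (a sketch only; the patch keeps the per-argument recursion so the check threads through arg_num one parameter at a time):

#include <tuple>
#include <type_traits>

// Does a binary functor's argument tuple match (FirstParamTy, SecondParamTy)?
// Flattened form of the recursion above, valid for arity == 2.
template <typename TupleLike, typename FirstParamTy, typename SecondParamTy>
constexpr bool matches_binary_params() {
  if constexpr (std::tuple_size_v<TupleLike> != 2)
    return false;
  else
    return std::is_same_v<std::tuple_element_t<0, TupleLike>, FirstParamTy> &&
           std::is_same_v<std::tuple_element_t<1, TupleLike>, SecondParamTy>;
}

static_assert(matches_binary_params<std::tuple<float, float>, float, float>());
static_assert(!matches_binary_params<std::tuple<float, int>, float, float>());
static_assert(!matches_binary_params<std::tuple<float>, float, float>());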
@@ -564,8 +581,17 @@ struct check_types {
 // Bottom case: if we got this far, assume correct type matching except
 // when there are no arguments (arity == 0).
-template <typename TupleLike, size_t arity>
-struct check_types<TupleLike, arity, arity> {
+template <
+    typename TupleLike,
+    typename FirstParamTy,
+    typename SecondParamTy,
+    size_t arity>
+struct check_binary_functor_types_for_specialization<
+    TupleLike,
+    FirstParamTy,
+    SecondParamTy,
+    arity,
+    arity> {
   constexpr static inline bool check() {
     if constexpr (arity != 0)
       return true;
@@ -573,12 +599,90 @@ struct check_types<TupleLike, arity, arity> {
   }
 };
-template <typename TupleLike>
-struct check_types<TupleLike, 0, 0> {
+template <typename TupleLike, typename FirstParamTy, typename SecondParamTy>
+struct check_binary_functor_types_for_specialization<
+    TupleLike,
+    FirstParamTy,
+    SecondParamTy,
+    0,
+    0> {
   constexpr static inline bool check() {
     return false;
   }
 };
+
+// The following is a list of type specializations for vectorized_templated
+// elementwise kernel. It refers to the first and second runtime types of the
+// arguments of a binary functor.
+constexpr std::array rt_binary_specializations = {
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<BFloat16>::value}),
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<Half>::value}),
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<float>::value})};
+
+bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
+  if (iter.ninputs() != 2)
+    return false;
+  for (auto spec : rt_binary_specializations)
+    if (iter.input_dtype(0) == spec[0] && iter.input_dtype(1) == spec[1])
+      return true;
+  return false;
+}
+
+template <int arg_index>
+struct type_specialized_kernel_launcher {
+  template <
+      typename func_t,
+      typename array_t,
+      typename inp_calc_t,
+      typename out_calc_t,
+      typename loader_t,
+      typename storer_t>
+  static void apply(
+      ScalarType arg0_t,
+      ScalarType arg1_t,
+      int64_t numel,
+      func_t f,
+      array_t data,
+      inp_calc_t input_offset_calculator,
+      out_calc_t output_offset_calculator,
+      loader_t loader,
+      storer_t storer) {
+    using traits = function_traits<func_t>;
+    using return_t = typename traits::result_type;
+    if (arg0_t == rt_binary_specializations[arg_index][0] &&
+        arg1_t == rt_binary_specializations[arg_index][1])
+      launch_vectorized_templated_kernel<
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          return_t,
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][0]>::t),
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][1]>::t)>(
+          numel,
+          f,
+          data,
+          input_offset_calculator,
+          output_offset_calculator,
+          loader,
+          storer);
+  }
+};
 } // namespace
 #endif
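
type_specialized_kernel_launcher is instantiated once per entry of rt_binary_specializations via a compile-time unroll; each instantiation guards itself with the runtime dtype comparison, so at most one of them launches. The exact definition of memory::detail::static_unroll is not shown in this diff; a minimal sketch of the pattern it implements, under that assumption:

// Sketch: call Func<0>::apply(args...), Func<1>::apply(args...), ...,
// Func<end - 1>::apply(args...) with the same runtime arguments.
template <template <int> class Func, int end, int current = 0>
struct static_unroll_sketch {
  template <typename... Args>
  static void with_args(Args&&... args) {
    Func<current>::apply(args...); // pass as lvalues; reused below
    if constexpr (current + 1 < end)
      static_unroll_sketch<Func, end, current + 1>::with_args(args...);
  }
};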
@@ -608,43 +712,46 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
 #ifdef USE_ROCM
   // Attempt to call specialized vectorized elementwise kernel
   // that enables interleaving.
-  using float_map = c10::CppTypeToScalarType<float>;
-  using bfloat16_map = c10::CppTypeToScalarType<BFloat16>;
-  if (iter.ninputs() == 2 && iter.input_dtype(0) == float_map::value &&
-      iter.input_dtype(1) == bfloat16_map::value &&
+  if (check_binary_rt_types_for_specialization(iter) &&
       memory::can_vectorize_up_to<func_t>(data) > 1) {
-    // constexpr to reduce the amount of kernels (empty) generated for
+    // constexpr to reduce the amount of kernels generated for
     // vectorized templated elementwise and limit which functors are actually
     // applied to the load and store at compile time.
     using func_tuple = typename traits::ArgsTuple;
     if constexpr (
         std::is_same_v<float, arg0_t> && traits::arity == 2 &&
-        check_types<func_tuple, traits::arity, 0>::check()) {
+        check_binary_functor_types_for_specialization<
+            func_tuple,
+            float,
+            float,
+            traits::arity,
+            /*arg_num=*/0>::check()) {
       // If we got here, we know we are in one of the specialized cases. We
       // need to translate the runtime type to a statically known type. This
       // is effectively hoisting to the host the switch over runtime type in
       // the kernel in fetch_and_cast. Loader, storer, offset calculators are
       // only needed for the remainder loop.
       auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
       auto output_offset_calculator = TrivialOffsetCalculator<1>();
       auto loader = memory::LoadWithCast<traits::arity>(iter);
       auto storer = memory::StoreWithCast<1>(iter);
-      launch_vectorized_templated_kernel<
-          func_t,
-          std::array<char*, ntensors>,
-          decltype(input_offset_calculator),
-          decltype(output_offset_calculator),
-          decltype(loader),
-          decltype(storer),
-          float,
-          float,
-          BFloat16>(
-          numel,
-          f,
-          data,
-          input_offset_calculator,
-          output_offset_calculator,
-          loader,
-          storer);
+      memory::detail::static_unroll<
+          type_specialized_kernel_launcher,
+          rt_binary_specializations.size()>::
+          with_args(
+              iter.input_dtype(0),
+              iter.input_dtype(1),
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
       return;
     }
   }
 
   std::array<ScalarType, ntensors> dtypes;
   auto inner_strides = iter.get_inner_strides();
   std::array<int, ntensors> strides;
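
The net effect, as the comment in the hunk notes, is hoisting fetch_and_cast's per-element runtime type switch up to a single host-side decision. A self-contained toy of that idea (all names are illustrative stand-ins, not ATen's):

#include <cstdio>

enum class DType { Float, Half, BFloat16 };

struct Bf16 { unsigned short bits; }; // 16-bit stand-ins for c10 types
struct Fp16 { unsigned short bits; };

// Stand-in for launch_vectorized_templated_kernel: statically typed.
template <typename In0, typename In1>
void launch_typed(int numel) {
  std::printf("typed launch: numel=%d sizeof(In0)=%zu sizeof(In1)=%zu\n",
              numel, sizeof(In0), sizeof(In1));
}

// One host-side switch replaces a per-element cast switch in the kernel.
void dispatch(DType dt0, DType dt1, int numel) {
  if (dt0 == DType::Float && dt1 == DType::BFloat16)
    launch_typed<float, Bf16>(numel);
  else if (dt0 == DType::BFloat16 && dt1 == DType::Float)
    launch_typed<Bf16, float>(numel);
  else if (dt0 == DType::Float && dt1 == DType::Half)
    launch_typed<float, Fp16>(numel);
  else if (dt0 == DType::Half && dt1 == DType::Float)
    launch_typed<Fp16, float>(numel);
}

int main() {
  dispatch(DType::Float, DType::BFloat16, 1024);
  return 0;
}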

File 2 of 2:

@@ -41,7 +41,7 @@ static OffsetCalculator<num_outputs> make_output_offset_calculator(const TensorIteratorBase& iter) {
   return OffsetCalculator<num_outputs>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
 }
 
-template<typename func_t, typename policy_t>
+template <bool reverted_idx = false, typename func_t, typename policy_t>
 __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
   using traits = function_traits<func_t>;
   using return_t = typename traits::result_type;
@@ -49,6 +49,8 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
   constexpr int elems_per_thread = policy_t::tws;
   int idx = blockIdx.x;
+  if constexpr (reverted_idx)
+    idx = gridDim.x - blockIdx.x - 1;
 
   return_t results[elems_per_thread];
   args_t args[elems_per_thread];
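
Because reverted_idx defaults to false, existing call sites of elementwise_kernel_helper compile and behave exactly as before; only the ROCm vectorized templated kernel in the first file opts in with reverted_idx = true, which is what moves the remainder chunk onto block 0 while every other block keeps the fully vectorized path.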