[ROCm] Extend vectorized elementwise kernel to more heterogeneous tensor types. (#149738)

This patch extends the initial support for "vectorized templated" kernels to the following input tensor type pairs: (BFloat16, float), (float, float16), and (float16, float).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149738
Approved by: https://github.com/jeffdaily
parent 2a9e737839
commit 5af9cb12b7
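For orientation, each pair above names the (first, second) runtime input dtypes of a binary functor. A rough, hypothetical illustration, not part of the patch, of a mixed-dtype op that becomes eligible for the specialized path on ROCm (where at::kCUDA maps to the HIP device), assuming the inputs also pass the vectorization checks in the hunks below:

    #include <ATen/ATen.h>

    int main() {
      // (float16, float) is one of the newly covered input dtype pairs.
      // Whether the specialized kernel is actually taken also depends on the
      // functor's compile-time argument types (see the trait further down).
      auto a = at::randn({1 << 20}, at::device(at::kCUDA).dtype(at::kHalf));
      auto b = at::randn({1 << 20}, at::device(at::kCUDA).dtype(at::kFloat));
      auto c = a + b; // binary elementwise op with heterogeneous input dtypes
      return 0;
    }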
@@ -292,8 +292,11 @@ __global__ void vectorized_templated_elementwise_kernel(
     out_calc_t out_calc,
     loader_t loader,
     storer_t storer) {
-  int remaining =
-      N - vectorized_templated_config::block_work_size() * blockIdx.x;
+  int remaining = N -
+      vectorized_templated_config::block_work_size() *
+          (gridDim.x - blockIdx.x - 1);
+  constexpr bool reverted_idx = true;
+
   if (remaining <
       vectorized_templated_config::block_work_size()) { // if this block handles
                                                         // the remainder,
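Note on the hunk above: with the old formula the tail (remaining < block_work_size()) fell to the highest blockIdx.x; with the reversed mapping, blockIdx.x == 0 picks up the remainder chunk while the last block gets a full one, and elementwise_kernel_helper<reverted_idx> (see the final hunk) reverses the chunk offsets to match. A standalone host-side sketch, with hypothetical sizes, checking that the reversed mapping still covers every element exactly once:

    #include <cassert>
    #include <vector>

    int main() {
      const int bws = 512; // stand-in for block_work_size()
      for (int N : {1, 511, 512, 513, 4096, 5000}) {
        const int grid = (N + bws - 1) / bws;
        std::vector<int> hits(N, 0);
        for (int block = 0; block < grid; ++block) {
          int logical = grid - block - 1;    // reverted index: block 0 -> last chunk
          int remaining = N - bws * logical; // matches the new formula
          int count = remaining < bws ? remaining : bws;
          for (int i = 0; i < count; ++i)
            ++hits[logical * bws + i];
        }
        for (int h : hits)
          assert(h == 1); // every element is covered exactly once
      }
    }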
@@ -307,18 +310,17 @@ __global__ void vectorized_templated_elementwise_kernel(
         storer_t,
         vectorized_templated_config::elems_per_thread()>(
         data, remaining, inp_calc, out_calc, loader, storer);
-    elementwise_kernel_helper(f, policy);
+    elementwise_kernel_helper<reverted_idx>(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use
            // vectorized memory access
-    elementwise_kernel_helper(
-        f,
-        memory::policies::vectorized_templated<
-            vec_size,
-            array_t,
-            vectorized_templated_config::elems_per_thread(),
-            vectorized_templated_config::num_threads(),
-            OutputType,
-            InputTypes...>(data));
+    auto policy = memory::policies::vectorized_templated<
+        vec_size,
+        array_t,
+        vectorized_templated_config::elems_per_thread(),
+        vectorized_templated_config::num_threads(),
+        OutputType,
+        InputTypes...>(data);
+    elementwise_kernel_helper<reverted_idx>(f, policy);
   }
 }
@@ -544,19 +546,34 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
 
 #ifdef USE_ROCM
 namespace {
-template <typename TupleLike, size_t arity, size_t arg_num = 0>
-struct check_types {
+template <
+    typename TupleLike,
+    typename FirstParamTy,
+    typename SecondParamTy,
+    size_t arity,
+    size_t arg_num = 0>
+struct check_binary_functor_types_for_specialization {
   constexpr static inline bool check() {
+    if constexpr (arity != 2)
+      return false;
     if constexpr (arg_num == 0) {
       using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
-      if constexpr (std::is_same_v<float, SelectedType>)
-        return check_types<TupleLike, arity, arg_num + 1>::check();
+      if constexpr (std::is_same_v<FirstParamTy, SelectedType>)
+        return check_binary_functor_types_for_specialization<
+            TupleLike,
+            FirstParamTy,
+            SecondParamTy,
+            arity,
+            arg_num + 1>::check();
     } else if constexpr (arg_num == 1) {
       using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
-      if constexpr (std::is_same_v<float, SelectedType2>)
-        return check_types<TupleLike, arity, arg_num + 1>::check();
+      if constexpr (std::is_same_v<SecondParamTy, SelectedType2>)
+        return check_binary_functor_types_for_specialization<
+            TupleLike,
+            FirstParamTy,
+            SecondParamTy,
+            arity,
+            arg_num + 1>::check();
     }
     return false;
   }
@@ -564,8 +581,17 @@ struct check_types {
 
 // Bottom case: if we got this far, assume correct type matching except
 // when there are no arguments (arity == 0).
-template <typename TupleLike, size_t arity>
-struct check_types<TupleLike, arity, arity> {
+template <
+    typename TupleLike,
+    typename FirstParamTy,
+    typename SecondParamTy,
+    size_t arity>
+struct check_binary_functor_types_for_specialization<
+    TupleLike,
+    FirstParamTy,
+    SecondParamTy,
+    arity,
+    arity> {
   constexpr static inline bool check() {
     if constexpr (arity != 0)
       return true;
@@ -573,12 +599,90 @@ struct check_types<TupleLike, arity, arity> {
   }
 };
 
-template <typename TupleLike>
-struct check_types<TupleLike, 0, 0> {
+template <typename TupleLike, typename FirstParamTy, typename SecondParamTy>
+struct check_binary_functor_types_for_specialization<
+    TupleLike,
+    FirstParamTy,
+    SecondParamTy,
+    0,
+    0> {
   constexpr static inline bool check() {
     return false;
   }
 };
 
+// The following is a list of type specializations for vectorized_templated
+// elementwise kernel. It refers to the first and second runtime types of the
+// arguments of a binary functor.
+constexpr std::array rt_binary_specializations = {
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<BFloat16>::value}),
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<Half>::value}),
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<float>::value})};
+
+bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
+  if (iter.ninputs() != 2)
+    return false;
+  for (auto spec : rt_binary_specializations)
+    if (iter.input_dtype(0) == spec[0] && iter.input_dtype(1) == spec[1])
+      return true;
+  return false;
+}
+
+template <int arg_index>
+struct type_specialized_kernel_launcher {
+  template <
+      typename func_t,
+      typename array_t,
+      typename inp_calc_t,
+      typename out_calc_t,
+      typename loader_t,
+      typename storer_t>
+  static void apply(
+      ScalarType arg0_t,
+      ScalarType arg1_t,
+      int64_t numel,
+      func_t f,
+      array_t data,
+      inp_calc_t input_offset_calculator,
+      out_calc_t output_offset_calculator,
+      loader_t loader,
+      storer_t storer) {
+    using traits = function_traits<func_t>;
+    using return_t = typename traits::result_type;
+    if (arg0_t == rt_binary_specializations[arg_index][0] &&
+        arg1_t == rt_binary_specializations[arg_index][1])
+      launch_vectorized_templated_kernel<
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          return_t,
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][0]>::t),
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][1]>::t)>(
+          numel,
+          f,
+          data,
+          input_offset_calculator,
+          output_offset_calculator,
+          loader,
+          storer);
+  }
+};
+
 } // namespace
 #endif
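The machinery above splits the check in two: check_binary_functor_types_for_specialization walks the functor's argument tuple at compile time, comparing each parameter type against a requested (FirstParamTy, SecondParamTy) pair, while rt_binary_specializations and check_binary_rt_types_for_specialization handle the runtime dtype side. A condensed, self-contained sketch of how the compile-time trait evaluates; `check` here is a shortened stand-in for the real trait, not the patch's code verbatim:

    #include <cstddef>
    #include <tuple>
    #include <type_traits>

    template <
        typename Tuple,
        typename P0,
        typename P1,
        std::size_t arity,
        std::size_t n = 0>
    struct check {
      constexpr static bool value() {
        if constexpr (arity != 2)
          return false; // only binary functors qualify
        if constexpr (n == 0) {
          if constexpr (std::is_same_v<P0, std::tuple_element_t<0, Tuple>>)
            return check<Tuple, P0, P1, arity, n + 1>::value();
        } else if constexpr (n == 1) {
          if constexpr (std::is_same_v<P1, std::tuple_element_t<1, Tuple>>)
            return check<Tuple, P0, P1, arity, n + 1>::value();
        }
        return false; // some parameter type did not match
      }
    };

    // Bottom case: every argument matched (and there was at least one).
    template <typename Tuple, typename P0, typename P1, std::size_t arity>
    struct check<Tuple, P0, P1, arity, arity> {
      constexpr static bool value() {
        return arity != 0;
      }
    };

    // A float(float, float) functor qualifies for the (float, float) pair...
    static_assert(check<std::tuple<float, float>, float, float, 2>::value());
    // ...but a mismatched parameter or a non-binary functor does not.
    static_assert(!check<std::tuple<float, double>, float, float, 2>::value());
    static_assert(!check<std::tuple<float>, float, float, 1>::value());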
@@ -608,43 +712,46 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
 #ifdef USE_ROCM
     // Attempt to call specialized vectorized elementwise kernel
     // that enables interleaving.
-    using float_map = c10::CppTypeToScalarType<float>;
-    using bfloat16_map = c10::CppTypeToScalarType<BFloat16>;
-    if (iter.ninputs() == 2 && iter.input_dtype(0) == float_map::value &&
-        iter.input_dtype(1) == bfloat16_map::value &&
+    if (check_binary_rt_types_for_specialization(iter) &&
         memory::can_vectorize_up_to<func_t>(data) > 1) {
-      // constexpr to reduce the amount of kernels (empty) generated for
+      // constexpr to reduce the amount of kernels generated for
       // vectorized templated elementwise and limit which functors are actually
       // applied to the load and store at compile time.
       using func_tuple = typename traits::ArgsTuple;
       if constexpr (
           std::is_same_v<float, arg0_t> && traits::arity == 2 &&
-          check_types<func_tuple, traits::arity, 0>::check()) {
+          check_binary_functor_types_for_specialization<
+              func_tuple,
+              float,
+              float,
+              traits::arity,
+              /*arg_num=*/0>::check()) {
         // If we got here, we know we are in one of the specialized cases. We
         // need to translate the runtime type to a statically known type. This
         // is effectively hoisting to the host the switch over runtime type in
         // the kernel in fetch_and_cast. Loader, storer, offset calculators are
         // only needed for the remainder loop.
         auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
         auto output_offset_calculator = TrivialOffsetCalculator<1>();
         auto loader = memory::LoadWithCast<traits::arity>(iter);
         auto storer = memory::StoreWithCast<1>(iter);
-        launch_vectorized_templated_kernel<
-            func_t,
-            std::array<char*, ntensors>,
-            decltype(input_offset_calculator),
-            decltype(output_offset_calculator),
-            decltype(loader),
-            decltype(storer),
-            float,
-            float,
-            BFloat16>(
-            numel,
-            f,
-            data,
-            input_offset_calculator,
-            output_offset_calculator,
-            loader,
-            storer);
+        memory::detail::static_unroll<
+            type_specialized_kernel_launcher,
+            rt_binary_specializations.size()>::
+            with_args(
+                iter.input_dtype(0),
+                iter.input_dtype(1),
+                numel,
+                f,
+                data,
+                input_offset_calculator,
+                output_offset_calculator,
+                loader,
+                storer);
         return;
       }
     }
 
     std::array<ScalarType, ntensors> dtypes;
     auto inner_strides = iter.get_inner_strides();
     std::array<int, ntensors> strides;
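The call-site change above replaces the single hard-coded (float, BFloat16) launch with a compile-time unrolled loop: static_unroll instantiates type_specialized_kernel_launcher<i> for every entry of rt_binary_specializations, and each instance launches only if its pair equals the runtime dtypes. A simplified standalone sketch of that dispatch pattern, with hypothetical names and a printf standing in for the kernel launch:

    #include <array>
    #include <cstdio>

    // Hypothetical stand-ins for c10::ScalarType and the specialization table.
    enum class ScalarType { Float, BFloat16, Half };

    constexpr std::array<std::array<ScalarType, 2>, 4> specializations{{
        {ScalarType::Float, ScalarType::BFloat16},
        {ScalarType::BFloat16, ScalarType::Float},
        {ScalarType::Float, ScalarType::Half},
        {ScalarType::Half, ScalarType::Float},
    }};

    // Each instantiation guards itself: only the entry matching the runtime
    // dtype pair does any work.
    template <int i>
    struct launcher {
      static void apply(ScalarType a, ScalarType b) {
        if (a == specializations[i][0] && b == specializations[i][1])
          std::printf("launch specialization %d\n", i);
      }
    };

    // Condensed version of the static_unroll pattern: calls F<0>..F<end-1>.
    template <template <int> class F, int end, int cur = 0>
    struct static_unroll {
      template <typename... Args>
      static void with_args(Args... args) {
        F<cur>::apply(args...);
        static_unroll<F, end, cur + 1>::with_args(args...);
      }
    };
    template <template <int> class F, int end>
    struct static_unroll<F, end, end> {
      template <typename... Args>
      static void with_args(Args...) {}
    };

    int main() {
      static_unroll<launcher, specializations.size()>::with_args(
          ScalarType::Half, ScalarType::Float); // prints "launch specialization 3"
    }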
@@ -41,7 +41,7 @@ static OffsetCalculator<num_outputs> make_output_offset_calculator(const TensorIteratorBase& iter) {
   return OffsetCalculator<num_outputs>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
 }
 
-template<typename func_t, typename policy_t>
+template <bool reverted_idx = false, typename func_t, typename policy_t>
 __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
   using traits = function_traits<func_t>;
   using return_t = typename traits::result_type;
@@ -49,6 +49,8 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
   constexpr int elems_per_thread = policy_t::tws;
 
   int idx = blockIdx.x;
+  if constexpr (reverted_idx)
+    idx = gridDim.x - blockIdx.x - 1;
 
   return_t results[elems_per_thread];
   args_t args[elems_per_thread];
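The defaulted reverted_idx parameter keeps every existing elementwise_kernel_helper(f, policy) call site compiling unchanged; only the vectorized templated kernel opts in. A minimal CUDA sketch, hypothetical and not from the patch, of the same pattern:

    #include <cstdio>

    // The defaulted non-type template parameter reverses the logical block
    // index at compile time; when unused there is no runtime branch.
    template <bool reverted_idx = false>
    __global__ void index_demo() {
      int idx = blockIdx.x;
      if constexpr (reverted_idx)
        idx = gridDim.x - blockIdx.x - 1; // block 0 takes the last logical chunk
      printf("block %u -> logical chunk %d\n", blockIdx.x, idx);
    }

    int main() {
      index_demo<<<4, 1>>>();       // identity: blocks 0..3 -> chunks 0..3
      index_demo<true><<<4, 1>>>(); // reversed: blocks 0..3 -> chunks 3..0
      cudaDeviceSynchronize();
    }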