mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73749 Unused parameters cause compiler warnings which distract from real issues. Let's remove unused parameters! Test Plan: Sandcastle Reviewed By: swolchok, ngimel Differential Revision: D34567731 fbshipit-source-id: 2e42301a29a8e1014ac8ab429588bb773db58850 (cherry picked from commit 3eda4743991328d532194efd0fe3d127a294343d)
486 lines
15 KiB
C++
486 lines
15 KiB
C++
#pragma once
|
|
|
|
#include <c10/util/Array.h>
#include <c10/util/TypeList.h>

#include <array>
#include <cstddef>
#include <functional>
#include <tuple>
#include <type_traits>
#include <utility>
|
|
|
|
namespace c10 {
|
|
namespace guts {
|
|
|
|
/**
|
|
* Access information about result type or arguments from a function type.
|
|
* Example:
|
|
* using A = function_traits<int (float, double)>::return_type // A == int
|
|
* using A = function_traits<int (float, double)>::parameter_types::tuple_type
|
|
* // A == tuple<float, double>
|
|
*/
|
|
template <class Func>
|
|
struct function_traits {
|
|
static_assert(
|
|
!std::is_same<Func, Func>::value,
|
|
"In function_traits<Func>, Func must be a plain function type.");
|
|
};
|
|
template <class Result, class... Args>
|
|
struct function_traits<Result(Args...)> {
|
|
using func_type = Result(Args...);
|
|
using return_type = Result;
|
|
using parameter_types = typelist::typelist<Args...>;
|
|
static constexpr auto number_of_parameters = sizeof...(Args);
|
|
};
|
|
|
|
/**
|
|
* infer_function_traits: creates a `function_traits` type for a simple
|
|
* function (pointer) or functor (lambda/struct). Currently does not support
|
|
* class methods.
|
|
*/
|
|
|
|
template <typename Functor>
|
|
struct infer_function_traits {
|
|
using type = function_traits<
|
|
c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>;
|
|
};
|
|
|
|
template <typename Result, typename... Args>
|
|
struct infer_function_traits<Result (*)(Args...)> {
|
|
using type = function_traits<Result(Args...)>;
|
|
};
|
|
|
|
template <typename Result, typename... Args>
|
|
struct infer_function_traits<Result(Args...)> {
|
|
using type = function_traits<Result(Args...)>;
|
|
};
|
|
|
|
template <typename T>
|
|
using infer_function_traits_t = typename infer_function_traits<T>::type;
|
|
|
|
/**
|
|
* make_function_traits: creates a `function_traits` type given a Return type
|
|
* and a typelist of Argument types
|
|
*
|
|
* Example:
|
|
* bool f(int, int);
|
|
*
|
|
* infer_function_traits_t<decltype(f)> == make_function_traits_t<bool,
* typelist::typelist<int, int>>
|
|
*/
|
|
template <typename Result, typename ArgList>
|
|
struct make_function_traits {
|
|
static_assert(
|
|
false_t<ArgList>::value,
|
|
"In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>.");
|
|
};
|
|
|
|
template <typename Result, typename... Args>
|
|
struct make_function_traits<Result, typelist::typelist<Args...>> {
|
|
using type = function_traits<Result(Args...)>;
|
|
};
|
|
|
|
template <typename Result, typename ArgList>
|
|
using make_function_traits_t =
|
|
typename make_function_traits<Result, ArgList>::type;
|
|
|
|
/**
|
|
* Use extract_arg_by_filtered_index to return the i-th argument whose
|
|
* type fulfills a given type trait. The argument itself is perfectly forwarded.
|
|
*
|
|
* Example:
|
|
* std::string arg1 = "Hello";
|
|
* std::string arg2 = "World";
|
|
* std::string&& result = extract_arg_by_filtered_index<is_string, 1>(0,
|
|
* arg1, 2.0, std::move(arg2));
|
|
*
|
|
* Warning: Taking the result by rvalue reference can cause segfaults because
* ownership will not be passed on from the original reference. The original
* reference dies after the expression and the resulting reference is then
* left dangling.
|
|
*/
|
|
namespace detail {

// Recursive worker for extract_arg_by_filtered_index.
//
// Walks the argument list front to back. `remaining` counts how many
// arguments satisfying `Predicate` still have to be skipped before the wanted
// one is reached. The `Enable` slot is always instantiated with `void`; the
// partial specializations below use it to select exactly one case per step.
template <
    template <typename> class Predicate,
    size_t remaining,
    typename Enable,
    typename... Args>
struct extract_arg_by_filtered_index_;

// No arguments left: the requested filtered index does not exist. The
// condition depends on `remaining`, so the assertion only fires when call()
// is instantiated.
template <template <typename> class Predicate, size_t remaining>
struct extract_arg_by_filtered_index_<Predicate, remaining, void> {
  static void call() {
    static_assert(
        remaining != remaining, "extract_arg_by_filtered_index out of range.");
  }
};

// First argument satisfies the predicate and nothing is left to skip:
// forward it to the caller unchanged.
template <
    template <typename> class Predicate,
    size_t remaining,
    typename First,
    typename... Rest>
struct extract_arg_by_filtered_index_<
    Predicate,
    remaining,
    std::enable_if_t<Predicate<First>::value && remaining == 0>,
    First,
    Rest...> {
  static decltype(auto) call(First&& first, Rest&&... /*rest*/) {
    return std::forward<First>(first);
  }
};

// First argument satisfies the predicate but is not the wanted one yet:
// drop it and continue with one fewer match to skip.
template <
    template <typename> class Predicate,
    size_t remaining,
    typename First,
    typename... Rest>
struct extract_arg_by_filtered_index_<
    Predicate,
    remaining,
    std::enable_if_t<Predicate<First>::value && remaining != 0>,
    First,
    Rest...> {
  static decltype(auto) call(First&& /*first*/, Rest&&... rest) {
    using next =
        extract_arg_by_filtered_index_<Predicate, remaining - 1, void, Rest...>;
    return next::call(std::forward<Rest>(rest)...);
  }
};

// First argument does not satisfy the predicate: drop it without touching
// the counter.
template <
    template <typename> class Predicate,
    size_t remaining,
    typename First,
    typename... Rest>
struct extract_arg_by_filtered_index_<
    Predicate,
    remaining,
    std::enable_if_t<!Predicate<First>::value>,
    First,
    Rest...> {
  static decltype(auto) call(First&& /*first*/, Rest&&... rest) {
    using next =
        extract_arg_by_filtered_index_<Predicate, remaining, void, Rest...>;
    return next::call(std::forward<Rest>(rest)...);
  }
};

} // namespace detail
|
|
template <template <class> class Condition, size_t index, class... Args>
|
|
decltype(auto) extract_arg_by_filtered_index(Args&&... args) {
|
|
static_assert(
|
|
is_type_condition<Condition>::value,
|
|
"In extract_arg_by_filtered_index, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
|
|
return detail::
|
|
extract_arg_by_filtered_index_<Condition, index, void, Args...>::call(
|
|
std::forward<Args>(args)...);
|
|
}
|
|
|
|
/**
|
|
* Use filter_map to map a subset of the arguments to values.
|
|
* The subset is defined by type traits, and will be evaluated at compile time.
|
|
* At runtime, it will just loop over the pre-filtered arguments to create an
|
|
* std::array.
|
|
*
|
|
* Example:
|
|
* std::array<double, 2> result = filter_map<double, std::is_integral>([] (auto
|
|
* a) {return (double)a;}, 3, "bla", 4);
|
|
* // result == {3.0, 4.0}
|
|
*/
|
|
namespace detail {
|
|
|
|
template <class ResultType, size_t num_results>
|
|
struct filter_map_ {
|
|
template <
|
|
template <class>
|
|
class Condition,
|
|
class Mapper,
|
|
class... Args,
|
|
size_t... INDEX>
|
|
static guts::array<ResultType, num_results> call(
|
|
const Mapper& mapper,
|
|
std::index_sequence<INDEX...>,
|
|
Args&&... args) {
|
|
return guts::array<ResultType, num_results>{
|
|
mapper(extract_arg_by_filtered_index<Condition, INDEX>(
|
|
std::forward<Args>(args)...))...};
|
|
}
|
|
};
|
|
template <class ResultType>
|
|
struct filter_map_<ResultType, 0> {
|
|
template <
|
|
template <class>
|
|
class Condition,
|
|
class Mapper,
|
|
class... Args,
|
|
size_t... INDEX>
|
|
static guts::array<ResultType, 0> call(
|
|
const Mapper& /*mapper*/,
|
|
std::index_sequence<INDEX...>,
|
|
Args&&... /*args*/) {
|
|
return guts::array<ResultType, 0>{};
|
|
}
|
|
};
|
|
} // namespace detail
|
|
|
|
template <
|
|
class ResultType,
|
|
template <class>
|
|
class Condition,
|
|
class Mapper,
|
|
class... Args>
|
|
decltype(auto) filter_map(const Mapper& mapper, Args&&... args) {
|
|
static_assert(
|
|
is_type_condition<Condition>::value,
|
|
"In filter_map<Result, Condition>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
|
|
|
|
static constexpr size_t num_results =
|
|
typelist::count_if<Condition, typelist::typelist<Args...>>::value;
|
|
return detail::filter_map_<ResultType, num_results>::
|
|
template call<Condition, Mapper, Args...>(
|
|
mapper,
|
|
std::make_index_sequence<num_results>(),
|
|
std::forward<Args>(args)...);
|
|
}
|
|
|
|
/**
|
|
* make_offset_index_sequence<Start, N>
|
|
* Like make_index_sequence<N>, but starting from Start instead of 0.
|
|
*
|
|
* Example:
|
|
* make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12>
|
|
*/
|
|
// Recursive step: put the largest index still missing (Start + N - 1) in
// front of the indices collected so far and continue with N - 1.
template <size_t Start, size_t N, size_t... Collected>
struct make_offset_index_sequence_impl
    : make_offset_index_sequence_impl<
          Start,
          N - 1,
          Start + N - 1,
          Collected...> {
  // Start and N are unsigned; these checks look at their value after
  // conversion to int.
  static_assert(
      static_cast<int>(Start) >= 0,
      "make_offset_index_sequence: Start < 0");
  static_assert(static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0");
};

// Recursion end (N == 0): expose the collected indices as `type`.
template <size_t Start, size_t... Collected>
struct make_offset_index_sequence_impl<Start, 0, Collected...> {
  using type = std::index_sequence<Collected...>;
};

// std::index_sequence<Start, Start + 1, ..., Start + N - 1>
template <size_t Start, size_t N>
using make_offset_index_sequence =
    typename make_offset_index_sequence_impl<Start, N>::type;
|
|
|
|
/**
|
|
* Use tuple_elements to extract a position-indexed subset of elements
|
|
* from the argument tuple into a result tuple.
|
|
*
|
|
* Example:
|
|
* std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
|
|
* std::tuple<int, double> result = tuple_elements(t, std::index_sequence<0,
|
|
* 2>());
|
|
*/
|
|
// Builds a new tuple from the elements of `t` at positions Is..., in the
// order given. Element types are taken from Tuple unchanged and each selected
// element is read from the by-value parameter `t`.
template <class Tuple, size_t... Is>
constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...>) {
  using selected = std::tuple<std::tuple_element_t<Is, Tuple>...>;
  return selected(std::get<Is>(t)...);
}
|
|
|
|
/**
|
|
* Use tuple_take to extract the first or last n elements from the argument
|
|
* tuple into a result tuple.
|
|
*
|
|
* Example:
|
|
* std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
|
|
* std::tuple<int, const char*> first_two = tuple_take<decltype(t), 2>(t);
|
|
* std::tuple<const char*, double> last_two = tuple_take<decltype(t), -2>(t);
|
|
*/
|
|
template <class Tuple, int N, class Enable = void>
|
|
struct TupleTake {};
|
|
|
|
template <class Tuple, int N>
|
|
struct TupleTake<Tuple, N, std::enable_if_t<N >= 0, void>> {
|
|
static auto call(Tuple t) {
|
|
constexpr size_t size = std::tuple_size<Tuple>();
|
|
static_assert(N <= size, "tuple_take: N > size");
|
|
return tuple_elements(t, std::make_index_sequence<N>{});
|
|
}
|
|
};
|
|
|
|
template <class Tuple, int N>
|
|
struct TupleTake < Tuple,
|
|
N, std::enable_if_t<N<0, void>> {
|
|
static auto call(Tuple t) {
|
|
constexpr size_t size = std::tuple_size<Tuple>();
|
|
static_assert(-N <= size, "tuple_take: -N > size");
|
|
return tuple_elements(t, make_offset_index_sequence<size + N, -N>{});
|
|
}
|
|
};
|
|
|
|
template <class Tuple, int N>
|
|
auto tuple_take(Tuple t) {
|
|
return TupleTake<Tuple, N>::call(t);
|
|
}
|
|
|
|
/**
|
|
* Use tuple_slice to extract a contiguous subtuple from the argument.
|
|
*
|
|
* Example:
|
|
* std::tuple<int, const char*, double, bool> t = std::make_tuple(0,
* "HEY", 2.0, false);
* std::tuple<const char*, double> middle_two =
* tuple_slice<decltype(t), 1, 2>(t);
|
|
*/
|
|
template <class Tuple, size_t Start, size_t N>
|
|
constexpr auto tuple_slice(Tuple t) {
|
|
constexpr size_t size = std::tuple_size<Tuple>();
|
|
static_assert(Start + N <= size, "tuple_slice: Start + N > size");
|
|
return tuple_elements(t, make_offset_index_sequence<Start, N>{});
|
|
}
|
|
|
|
/**
|
|
* Use tuple_map to run a mapping function over a tuple to get a new tuple.
|
|
*
|
|
* Example 1:
|
|
* auto result = tuple_map(std::tuple<int32_t, int32_t, int32_t>(3, 4, 5), []
|
|
* (int32_t a) -> int16_t {return a+1;});
|
|
* // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6)
|
|
*
|
|
* Example 2:
|
|
* struct Mapper {
|
|
* std::string operator()(int32_t a) const {
|
|
* return std::to_string(a);
|
|
* }
|
|
* int64_t operator()(const std::string& a) const {
|
|
* return atoi(a.c_str());
|
|
* }
|
|
* };
|
|
* auto result = tuple_map(std::tuple<int32_t, std::string>(3, "4"),
|
|
* Mapper());
|
|
* // result == std::tuple<std::string, int64_t>("3", 4)
|
|
*
|
|
* Example 3:
|
|
* struct A final {
|
|
* int32_t func() {
|
|
* return 5;
|
|
* }
|
|
* };
|
|
* struct B final {
|
|
* std::string func() {
|
|
* return "5";
|
|
* }
|
|
* };
|
|
* auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return
|
|
* a.func(); });
|
|
* // result == std::tuple<int32_t, std::string>(5, "5");
|
|
*/
|
|
namespace detail {
// Applies `mapper` to the element at each position in Indices... and packs
// the results into a tuple whose element types are the mapper's return types.
//
// The result is built with a braced initializer list on purpose: its clauses
// are evaluated in the order written, so the mapper is called on the elements
// from first to last. With a parenthesized constructor call the order of the
// mapper calls would be unspecified.
template <class Mapper, class... Args, size_t... Indices>
auto tuple_map(
    std::tuple<Args...>&& tuple,
    const Mapper& mapper,
    std::index_sequence<Indices...>) {
  return std::tuple<decltype(mapper(
      std::forward<Args>(std::get<Indices>(tuple))))...>{
      mapper(std::forward<Args>(std::get<Indices>(tuple)))...};
}
} // namespace detail

// Runs `mapper` over every element of `tuple`, first to last, and returns the
// results as a new tuple. Each element is forwarded to the mapper with the
// value category implied by its declared element type.
template <class Mapper, class... Args>
auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) {
  return detail::tuple_map(
      std::move(tuple), mapper, std::index_sequence_for<Args...>());
}
|
|
|
|
/**
|
|
* tuple_concat concatenates several tuples into one.
|
|
*/
|
|
|
|
namespace detail {
// extract_tuple_element_by_index is a helper that takes a list of tuples and
// extracts the i-th element in a flattened view of the tuples. Example:
// extract_tuple_element_by_index<3>(tuple(2,3), tuple(4,5), tuple(6,7)) == 5.
//
// HeadTuple is deduced from a forwarding reference, so it is an lvalue
// reference type when an lvalue tuple is passed. std::tuple_size is not
// defined for reference types, hence the reference is removed before asking
// for the size; this lets both overloads accept lvalue and rvalue tuples.

// Case 1: the index falls inside the first tuple; return that element with
// the value category of the forwarded tuple.
template <
    size_t index,
    class HeadTuple,
    class... TailTuples,
    std::enable_if_t<
        (index < std::tuple_size<std::remove_reference_t<HeadTuple>>::value),
        int> = 0>
decltype(auto) extract_tuple_element_by_index(
    HeadTuple&& head_tuple,
    TailTuples&&... /*tail_tuples*/) {
  // TODO if constexpr instead of enable_if
  return std::get<index>(std::forward<HeadTuple>(head_tuple));
}

// Case 2: the index lies beyond the first tuple; skip it and look up the
// correspondingly reduced index in the remaining tuples.
template <
    size_t index,
    class HeadTuple,
    class... TailTuples,
    std::enable_if_t<
        (index >= std::tuple_size<std::remove_reference_t<HeadTuple>>::value),
        int> = 0>
decltype(auto) extract_tuple_element_by_index(
    HeadTuple&& /*head_tuple*/,
    TailTuples&&... tail_tuples) {
  // TODO if constexpr instead of enable_if
  return extract_tuple_element_by_index<
      index - std::tuple_size<std::remove_reference_t<HeadTuple>>::value,
      TailTuples...>(std::forward<TailTuples>(tail_tuples)...);
}

static_assert(
    std::is_same<
        int&&,
        decltype(extract_tuple_element_by_index<2>(
            std::tuple<int32_t>(2),
            std::tuple<int32_t&&, int32_t>(std::declval<int32_t>(), 3)))>::
        value,
    "extract_tuple_element_by_index should return rvalue references if the tuple contains them. It should not move them into a value");

// Builds ConcatenatedTuple by extracting, for every flattened position in
// ElementIndices..., the matching element from the given tuples. The tuple
// pack is forwarded once per position; each extraction hands back only the
// element it selects.
template <class ConcatenatedTuple, class... Tuples, size_t... ElementIndices>
auto tuple_concat(Tuples&&... tuples, std::index_sequence<ElementIndices...>) {
  return ConcatenatedTuple(extract_tuple_element_by_index<ElementIndices>(
      std::forward<Tuples>(tuples)...)...);
}
} // namespace detail
|
|
|
|
template <class... Tuples>
|
|
auto tuple_concat(Tuples&&... tuples) {
|
|
using flattened_types =
|
|
guts::typelist::concat_t<guts::typelist::from_tuple_t<Tuples>...>;
|
|
using concatenated_tuple = guts::typelist::to_tuple_t<flattened_types>;
|
|
constexpr size_t num_elements = guts::typelist::size<flattened_types>::value;
|
|
return detail::tuple_concat<concatenated_tuple, Tuples...>(
|
|
std::forward<Tuples>(tuples)...,
|
|
std::make_index_sequence<num_elements>());
|
|
}
|
|
|
|
/**
|
|
* Concatenate multiple integer sequences
|
|
* Example:
|
|
* concat_iseq_t<std::index_sequence<2, 5, 3>, std::index_sequence<4, 2>,
|
|
* std::index_sequence<5>>
|
|
* == std::index_sequence<2, 5, 3, 4, 2, 5>
|
|
*/
|
|
template <class... ISeqs>
|
|
struct concat_iseq {
|
|
static_assert(
|
|
false_t<ISeqs...>::value,
|
|
"In concat_iseq<T1, ...>, the T arguments each must be std::integer_sequence<...> with the same IntType.");
|
|
};
|
|
template <>
|
|
struct concat_iseq<> {
|
|
using type = std::index_sequence<>;
|
|
};
|
|
template <class IntType, IntType... Indices>
|
|
struct concat_iseq<std::integer_sequence<IntType, Indices...>> {
|
|
using type = std::integer_sequence<IntType, Indices...>;
|
|
};
|
|
template <
|
|
class IntType,
|
|
IntType... Head1Indices,
|
|
IntType... Head2Indices,
|
|
class... TailISeqs>
|
|
struct concat_iseq<
|
|
std::integer_sequence<IntType, Head1Indices...>,
|
|
std::integer_sequence<IntType, Head2Indices...>,
|
|
TailISeqs...> {
|
|
using type = typename concat_iseq<
|
|
std::integer_sequence<IntType, Head1Indices..., Head2Indices...>,
|
|
TailISeqs...>::type;
|
|
};
|
|
template <class... ISeqs>
|
|
using concat_iseq_t = typename concat_iseq<ISeqs...>::type;
|
|
|
|
} // namespace guts
|
|
} // namespace c10
|