#pragma once #include #include #include #include #include #include namespace torch { // This class allows you to write variadic functions which // call a (possibly overloaded) function on each argument, // in order. This is most commonly used in autogenerated code, // where it is convenient to have a function that can uniformly // take arguments of different types. If your arguments // are homogenous consider using a std::initializer_list instead. template struct IterArgs { template inline F& apply() { return self(); } // NB: Use perfect forwarding here, otherwise we'll make value // copies of all arguments! template inline F& apply(T&& arg, Args&&... args) { self()(std::forward(arg)); if (self().short_circuit()) { return self(); } else { return apply(std::forward(args)...); } } // Here are some handy overloads which provide sensible // defaults for container-like structures that one might // be interested in recursing into. You can enable them // by adding: // // using IterArgs::operator() // // to your struct. These are not enabled by default because // you may be able to process these structures more efficiently // than handling them one-by-one. template void operator()(at::ArrayRef args) { for (const auto& arg : args) { self()(arg); if (short_circuit()) return; } } // NB: we need to specify std::vector manually as C++ won't // do an implicit conversion to make a template deduction go through. template void operator()(const std::vector& args) { self()(at::ArrayRef{args}); } bool short_circuit() { return false; } private: inline F& self() { return *static_cast(this); } }; struct CountTensors : IterArgs { size_t out = 0; void operator()(const at::Tensor& x) { out += 1; } void operator()(at::ArrayRef xs) { out += xs.size(); } }; template size_t count_tensors(Args&&... args) { return CountTensors().apply(std::forward(args)...).out; } struct CountVariables : IterArgs { size_t out = 0; void operator()(const autograd::Variable& x) { out += 1; } void operator()(at::ArrayRef xs) { out += xs.size(); } }; template inline size_t count_variables(Args&&... args) { return CountVariables().apply(std::forward(args)...).out; } //===----------------------------------------------------------------------===// // std::index_sequence shim for C++11 //===----------------------------------------------------------------------===// // A container of type-template parameter indices. template struct Indices {}; // Decrements the index N, adds N-1 to the list of indices and forwards // whatever we arleady have. template struct MakeIndices : MakeIndices {}; // Partial specialization that forms our base case. When N is zero, we stop // and define a typedef that will be visible to earlier classes due to // inheritance. The typedef we define is an index list containing the numbers // 0 through N-1. template struct MakeIndices<0, Is...> { using indices = Indices; }; //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// template using enable_if_t = typename std::enable_if::type; template using disable_if_t = enable_if_t; template using decay_t = typename std::decay::type; namespace detail { template struct pack; } // namespace detail template struct all_of : std::is_same< detail::pack, detail::pack> {}; template struct any_of; template <> struct any_of<> : std::false_type {}; template struct any_of { static constexpr bool value = head || any_of::value; }; template struct none_of { static constexpr bool value = !any_of::value; }; template using enable_if_all_of_t = enable_if_t::value>; template using disable_if_contains_t = enable_if_all_of_t<(!std::is_same>::value)...>; template void apply(Function function, Ts&&... ts) { // https://stackoverflow.com/questions/13978916/inserting-a-variadic-argument-list-into-a-vector // Creates a dummy array, so that each function call is evaluated in order. // `(function(), 0)` is because `function` should (!) return `void`, so // according to the comma operator, it is evaluated and its result (`void`) // is discarded. Then the zero is evaluated and used as an element in the // array. The first zero ensures the array is not empty. int _[]{0, (function(std::forward(ts)), 0)...}; (void)_; } template ReturnType unpack(Function function, Accessor accessor) { return ReturnType(unpack( std::move(function), std::move(accessor), typename MakeIndices::indices())); } template ReturnType unpack(Function function, Accessor accessor, Indices) { return ReturnType(function(accessor.template operator()(Is)...)); } } // namespace torch