pytorch/torch/csrc/utils/variadic.h
Peter Goldsborough cb0e72e00d Add registerOperator overloads that infer the schema (#10048)
Summary:
This PR adds a way to infer the JIT/script schema of a function from its signature, and then create an operator from the schema and implementation. The implementation function is wrapped into another function, which pops values from the stack into an argument tuple, then invokes the function and pushes the return value back onto the stack, sometimes unpacking the return value if it is a tuple.

Currently the method is called `createOperator`. We may want to think of a nicer way of registering ops in tandem with `RegisterOperators`. It might be very cumbersome to add a template constructor to `Operator`, so maybe we can come up with a chaining method on `RegisterOperators` like `RegisterOperators(schema, func).op(schema, func).op(schema, func)` -- it has to work at startup time (for a static variable) though. We can solve this in another PR.

zdevito apaszke smessmer dzhulgakov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10048

Differential Revision: D9125975

Pulled By: goldsborough

fbshipit-source-id: de9e59888757573284a43787ae5d94384bfe8f9a
2018-08-03 11:45:49 -07:00

200 lines
5.8 KiB
C++

#pragma once
#include <ATen/ATen.h>
#include "torch/csrc/autograd/variable.h"
#include <cstdint>
#include <tuple>
#include <type_traits>
#include <utility>
namespace torch {
// This class allows you to write variadic functions which
// call a (possibly overloaded) function on each argument,
// in order. This is most commonly used in autogenerated code,
// where it is convenient to have a function that can uniformly
// take arguments of different types. If your arguments
// are homogeneous consider using a std::initializer_list instead.
//
// `F` is the derived visitor type (CRTP): it inherits from IterArgs<F>
// and supplies the `operator()` overloads that handle each argument.
// Every call made by this base class is routed through `self()` so that
// the derived class's overloads and hooks are the ones that run.
template <typename F>
struct IterArgs {
  // Base case of the recursion: nothing left to visit, hand back the
  // visitor so the caller can read the accumulated state.
  template <typename... Args>
  inline F& apply() {
    return self();
  }

  // Visits `arg`, then the remaining `args`, left to right. Stops early
  // as soon as the derived visitor's `short_circuit()` reports true.
  // Returns the derived visitor.
  //
  // NB: Use perfect forwarding here, otherwise we'll make value
  // copies of all arguments!
  template <typename T, typename... Args>
  inline F& apply(T&& arg, Args&&... args) {
    self()(std::forward<T>(arg));
    if (self().short_circuit()) {
      return self();
    } else {
      return apply(std::forward<Args>(args)...);
    }
  }

  // Here are some handy overloads which provide sensible
  // defaults for container-like structures that one might
  // be interested in recursing into. You can enable them
  // by adding:
  //
  //    using IterArgs<YourStructName>::operator()
  //
  // to your struct. These are not enabled by default because
  // you may be able to process these structures more efficiently
  // than handling them one-by-one.

  // Visits every element of `args` in order, stopping early when the
  // derived visitor asks to short-circuit.
  template <typename T>
  void operator()(at::ArrayRef<T> args) {
    for (const auto& arg : args) {
      self()(arg);
      // Ask the derived visitor, not this base class: the base
      // `short_circuit()` always returns false, so calling it directly
      // would ignore a derived early-exit hook (and would disagree with
      // what `apply` does above).
      if (self().short_circuit())
        return;
    }
  }

  // NB: we need to specify std::vector manually as C++ won't
  // do an implicit conversion to make a template deduction go through.
  template <typename T>
  void operator()(const std::vector<T>& args) {
    self()(at::ArrayRef<T>{args});
  }

  // Early-exit hook consulted after each visited element. The default
  // never stops; a derived visitor may declare its own `short_circuit()`
  // to end the traversal once it has seen enough.
  bool short_circuit() {
    return false;
  }

 private:
  // Downcast to the derived visitor type.
  inline F& self() {
    return *static_cast<F*>(this);
  }
};
// Visitor that tallies tensors: a single at::Tensor contributes one,
// an at::ArrayRef<at::Tensor> contributes its element count. The
// running total is kept in `out`.
struct CountTensors : IterArgs<CountTensors> {
  size_t out = 0;

  // A lone tensor; its contents are irrelevant, only its presence counts.
  void operator()(const at::Tensor& /* tensor */) {
    ++out;
  }

  // A list of tensors is counted in one step via its size.
  void operator()(at::ArrayRef<at::Tensor> tensors) {
    out = out + tensors.size();
  }
};
template <typename... Args>
size_t count_tensors(Args&&... args) {
return CountTensors().apply(std::forward<Args>(args)...).out;
}
// Visitor that tallies variables: a single autograd::Variable
// contributes one, an at::ArrayRef<autograd::Variable> contributes its
// element count. The running total is kept in `out`.
struct CountVariables : IterArgs<CountVariables> {
  size_t out = 0;

  // A lone variable; only its presence is counted.
  void operator()(const autograd::Variable& /* variable */) {
    ++out;
  }

  // A list of variables is counted in one step via its size.
  void operator()(at::ArrayRef<autograd::Variable> variables) {
    out = out + variables.size();
  }
};
template <typename... Args>
inline size_t count_variables(Args&&... args) {
return CountVariables().apply(std::forward<Args>(args)...).out;
}
//===----------------------------------------------------------------------===//
// std::index_sequence shim for C++11
//===----------------------------------------------------------------------===//
// A compile-time list of indices, carried purely in the type.
template <size_t... Indexes>
struct Indices {};

// Recursive step: decrements the counter `Remaining` and prepends
// `Remaining - 1` to whatever indices we have already accumulated.
template <size_t Remaining, size_t... Accumulated>
struct MakeIndices
    : MakeIndices<Remaining - 1, Remaining - 1, Accumulated...> {};

// Base case, selected once the counter reaches zero. The `indices`
// typedef defined here is inherited by every earlier step of the
// recursion, so `MakeIndices<N>::indices` names the index list
// containing the numbers 0 through N-1.
template <size_t... Accumulated>
struct MakeIndices<0, Accumulated...> {
  using indices = Indices<Accumulated...>;
};
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
// Shorthand for `std::enable_if<...>::type`: names `Type` (void by
// default) when `Condition` is true and is ill-formed otherwise.
template <bool Condition, typename Type = void>
using enable_if_t = typename std::enable_if<Condition, Type>::type;

// The inverse of enable_if_t: names `Type` only when `Condition` is false.
template <bool Condition, typename Type = void>
using disable_if_t = typename std::enable_if<!Condition, Type>::type;

// Shorthand for `std::decay<...>::type`.
template <typename Type>
using decay_t = typename std::decay<Type>::type;
namespace detail {
// Declaration-only holder for a list of booleans; it is only ever
// compared as a type and never instantiated.
template <bool...>
struct pack;
} // namespace detail

// True exactly when every boolean in `conditions` is true (and for an
// empty list). Appending `true` to the list and prepending `true` to it
// produce the same sequence only if no element is false, so comparing
// the two packs with std::is_same answers the question without
// recursion.
template <bool... conditions>
struct all_of : std::is_same<
                    detail::pack<conditions..., true>,
                    detail::pack<true, conditions...>> {};
// True when at least one of the given booleans is true; false for an
// empty list.
template <bool...>
struct any_of;

// Base case: an empty list contains no true value.
template <>
struct any_of<> : std::false_type {};

// Recursive case: true if the head is true or any element of the tail
// is. Derives from std::integral_constant (like `all_of` and the empty
// base case do) so that the trait offers the full standard interface
// (`::value`, `::type`, conversion to bool, usability for
// true_type/false_type tag dispatch) rather than a bare static member.
template <bool head, bool... tail>
struct any_of<head, tail...>
    : std::integral_constant<bool, (head || any_of<tail...>::value)> {};

// True when none of the given booleans is true, i.e. the negation of
// `any_of`; true for an empty list. Also a std::integral_constant, for
// the same reasons as `any_of`.
template <bool... values>
struct none_of
    : std::integral_constant<bool, !any_of<values...>::value> {};
template <bool... values>
using enable_if_all_of_t = enable_if_t<all_of<values...>::value>;
// Names `void` when `T` is not the same type as the decayed form of any
// of `Candidates`; ill-formed if at least one decayed candidate equals
// `T`. Note that only the candidates are decayed, `T` is compared as
// given.
template <typename T, typename... Candidates>
using disable_if_contains_t = enable_if_all_of_t<
    (!std::is_same<T, typename std::decay<Candidates>::type>::value)...>;
// Calls `function` once for each argument in `ts`, in the order the
// arguments were passed. Each argument is perfectly forwarded. With no
// arguments, `function` is not called at all.
template <typename Function, typename... Ts>
void apply(Function function, Ts&&... ts) {
  // https://stackoverflow.com/questions/13978916/inserting-a-variadic-argument-list-into-a-vector
  // Creates a dummy array, so that each function call is evaluated in order.
  // Each element is `(static_cast<void>(function(arg)), 0)`: the call is
  // evaluated, its result is discarded, and the zero becomes the array
  // element. `function` is expected to return `void`, but the explicit
  // cast to `void` makes any return type safe: without it, a return type
  // that overloads `operator,` would hijack the comma and the element
  // would no longer be an `int`. The first zero ensures the array is not
  // empty.
  int in_order[]{
      0, (static_cast<void>(function(std::forward<Ts>(ts))), 0)...};
  static_cast<void>(in_order);
}
template <typename... Ts, typename Function, typename Accessor>
auto unpack(Function function, Accessor accessor)
-> decltype(function(std::declval<Ts>()...)) {
return unpack<Ts...>(
std::move(function),
std::move(accessor),
typename MakeIndices<sizeof...(Ts)>::indices());
}
// Implementation overload of `unpack`: the unnamed `Indices<Is...>`
// parameter exists only so that the index pack `Is` can be deduced.
//
// `Ts` and `Is` are expanded together, so for every position the
// accessor's templated call operator is invoked as
// `accessor.operator()<T>(I)` with the type and the index for that
// position, and the results are passed as the arguments of `function`.
// The return value is whatever `function` returns.
//
// NOTE(review): the accessor calls are function arguments, and C++ does
// not fix the order in which function arguments are evaluated, so the
// accessor presumably must not depend on being called in index order --
// confirm against callers.
template <typename... Ts, typename Function, typename Accessor, size_t... Is>
auto unpack(Function function, Accessor accessor, Indices<Is...>)
-> decltype(function(std::declval<Ts>()...)) {
return function(accessor.template operator()<Ts>(Is)...);
}
} // namespace torch