pytorch/torch/csrc/autograd/functions/utils.h
Edward Z. Yang 1f3ff5ced2 Miscellaneous documentation around autograd. (#1577)
* Miscellaneous documentation around autograd.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-05-17 19:19:24 -04:00

38 lines
1.2 KiB
C++

#pragma once
#include <array>
#include <functional>
#include <iterator>
#include <memory>
#include <utility>

#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
namespace torch { namespace autograd {

/**
 * A callable that, given a FunctionFlags value, creates a new Function.
 * wrap_outputs() takes one of these to build the grad_fn for its outputs.
 */
using function_constructor = std::function<std::shared_ptr<Function>(FunctionFlags)>;

/**
 * Packs the arguments, in order, into a variable_list.
 *
 * Arguments are perfectly forwarded: rvalue arguments are moved into the
 * list, lvalue arguments are copied, so a caller's named objects are never
 * moved from behind its back. Each argument must be implicitly convertible
 * to variable_list::value_type (the staging array is copy-list-initialized).
 *
 * The elements are staged in a std::array rather than a braced initializer
 * list because std::initializer_list elements are const and could only be
 * copied into the result; the move iterators below let them be moved.
 */
template<typename ...Args>
inline variable_list as_variable_list(Args&& ... args) {
  std::array<variable_list::value_type, sizeof...(Args)> arr = { {std::forward<Args>(args)...} };
  return variable_list(std::make_move_iterator(arr.begin()),
                       std::make_move_iterator(arr.end()));
}

/**
 * Packs the arguments, in order, into a tensor_list.
 *
 * Same contract as as_variable_list: rvalue arguments are moved, lvalue
 * arguments are copied, and each argument must be implicitly convertible to
 * tensor_list::value_type. A std::array is used as staging storage so the
 * elements can be moved (not copied) into the result.
 */
template<typename ...Args>
inline tensor_list as_tensor_list(Args&& ... args) {
  std::array<tensor_list::value_type, sizeof...(Args)> arr = { {std::forward<Args>(args)...} };
  return tensor_list(std::make_move_iterator(arr.begin()),
                     std::make_move_iterator(arr.end()));
}

/**
 * Wraps the tensor outputs in variables, and if necessary (i.e., none of the
 * inputs are volatile), uses the function ctr and inputs to create a grad_fn
 * for each of them.
 *
 * @param inputs   the inputs of the operation that produced `outputs`.
 * @param outputs  the raw tensor outputs; taken by rvalue reference, so the
 *                 caller hands them over (e.g. via std::move or a temporary
 *                 such as the result of as_tensor_list).
 * @param ctr      factory used to build the grad_fn; NOTE(review): presumably
 *                 not invoked when no grad_fn is needed — confirm against the
 *                 definition.
 * @return the wrapped outputs; NOTE(review): presumably one variable per
 *         tensor in `outputs`, in the same order — confirm against the
 *         definition.
 */
variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
                           function_constructor ctr);

}} // namespace torch::autograd