pytorch/torch/csrc/autograd/functions/utils.cpp
2017-05-01 16:44:56 -04:00

28 lines
742 B
C++

#include "torch/csrc/autograd/functions/utils.h"
namespace torch { namespace autograd {
// Wraps each tensor in `outputs` into a Variable and returns them in order.
//
// inputs  - variables from which the function flags are computed via
//           Function::flags().
// outputs - tensors to wrap; every element is moved from, so the caller's
//           list holds moved-from entries afterwards.
// ctr     - factory invoked with the computed flags to build the grad
//           function; it is only called when the flags are not volatile.
//
// The returned list has one entry per element of `outputs`.
variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
                           function_constructor ctr) {
  auto fn_flags = Function::flags(inputs);

  variable_list wrapped;
  wrapped.reserve(outputs.size());

  // Volatile path: no grad function is constructed; each tensor is handed
  // straight to Variable::of with its second argument set to true.
  // NOTE(review): unlike the path below there is no explicit null check
  // here - presumably Variable::of copes with empty tensors; confirm.
  if (fn_flags.is_volatile) {
    for (auto& tensor : outputs) {
      wrapped.emplace_back(Variable::of(std::move(tensor), true));
    }
    return wrapped;
  }

  // Non-volatile path: build the grad function once and attach the same
  // instance to every Variable created below.
  auto grad_fn = ctr(std::move(fn_flags));
  for (auto& tensor : outputs) {
    if (!tensor) {
      // Keep positions aligned with `outputs`: a missing tensor becomes a
      // null entry in the result.
      wrapped.emplace_back(nullptr);
      continue;
    }
    wrapped.emplace_back(std::make_shared<Variable>(std::move(tensor), grad_fn));
  }
  return wrapped;
}
}}