mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
28 lines
742 B
C++
28 lines
742 B
C++
#include "torch/csrc/autograd/functions/utils.h"
|
|
|
|
namespace torch { namespace autograd {
|
|
|
|
variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
|
|
function_constructor ctr) {
|
|
auto flags = Function::flags(inputs);
|
|
variable_list result;
|
|
result.reserve(outputs.size());
|
|
if (flags.is_volatile) {
|
|
for (auto& output : outputs) {
|
|
result.emplace_back(Variable::of(std::move(output), true));
|
|
}
|
|
} else {
|
|
auto grad_fn = ctr(std::move(flags));
|
|
for (auto& output : outputs) {
|
|
if (output) {
|
|
result.emplace_back(std::make_shared<Variable>(std::move(output), grad_fn));
|
|
} else {
|
|
result.emplace_back(nullptr);
|
|
}
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
}}
|