pytorch/torch/csrc/autograd/input_buffer.cpp
2017-05-01 16:44:56 -04:00

51 lines
1.3 KiB
C++

#include "torch/csrc/autograd/input_buffer.h"
#include "torch/csrc/autograd/functions/basic_ops.h"
#include "torch/csrc/utils/auto_gpu.h"
namespace torch { namespace autograd {
// Creates a buffer with `size` slots. Every slot starts default-constructed,
// i.e. with no variable stored in it yet.
InputBuffer::InputBuffer(size_t size) : buffer(size) {}
// Accumulates `var` into slot `pos`.
//
// - A null `var` is ignored.
// - If the slot is still empty, `var` is moved into it together with the
//   current value of its version counter.
// - Otherwise the stored variable and `var` are summed through an `Add`
//   function, and the sum replaces the slot's contents with a version of 0
//   (presumably because the sum is a freshly created variable — TODO confirm).
//   In this branch `var` itself is not moved from.
//
// Throws std::out_of_range if `pos` is not a valid slot index. Previously an
// out-of-range `pos` silently wrote past the end of the buffer.
void InputBuffer::add(size_t pos, std::shared_ptr<Variable>&& var) {
  if (!var) {
    return;
  }
  // Bounds-checked: a bad slot index is a caller bug, not something to UB on.
  auto& item = buffer.at(pos);
  if (!item.first) {
    auto version = **var->version_counter;
    item = std::make_pair(std::move(var), version);
  } else {
    auto add_fn = std::make_shared<Add>();
    variable_list result = add_fn->apply({item.first, var});
    item = std::make_pair(std::move(result[0]), 0);
  }
}
// Returns the device index of the first buffered variable whose data is a
// CUDA tensor, or -1 if no slot holds a CUDA tensor (all slots empty or all
// CPU).
//
// CPU tensors are skipped on purpose. If the first populated slot held a CPU
// tensor and a later slot held a CUDA tensor, stopping at the first slot
// would hide the CUDA device from the caller.
auto InputBuffer::device() const -> int {
  for (const auto& pair : buffer) {
    if (pair.first && pair.first->data->isCuda()) {
      return pair.first->data->getDevice();
    }
  }
  return -1;
}
// Consumes `g` and returns its buffered variables in slot order.
//
// Empty slots come back as null shared_ptrs. The version numbers stored
// alongside each variable are discarded. `g` is moved from, so its contents
// must not be relied on after this call.
auto InputBuffer::variables(InputBuffer&& g) -> std::vector<std::shared_ptr<Variable>> {
  // Take ownership of the slots so each one can be moved out without copying.
  InputBuffer consumed = std::move(g);
  std::vector<std::shared_ptr<Variable>> result;
  result.reserve(consumed.buffer.size());
  for (auto& slot : consumed.buffer) {
    // Moving a null shared_ptr yields null, so empty slots need no special
    // case, and no refcount copy is made just to test for null.
    result.emplace_back(std::move(slot.first));
  }
  return result;
}
}} // namespace torch::autograd