#include <torch/csrc/autograd/functions/comm.h>

#include <ATen/core/functional.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/cuda/comm.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>

#include <cstddef>
#include <memory>
#include <vector>

namespace torch {
namespace autograd {
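// Scatter splits a single input tensor into per-device chunks via
// torch::cuda::scatter. `devices` lists the target devices, `chunk_sizes`
// optionally fixes the split sizes along `dim`, `streams` optionally supplies
// the streams used for the copies, and `unsqueeze_scalars` marks that the
// resulting 1-element chunks should be handed back as 0-dim scalars.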
Scatter::Scatter(
    std::vector<at::Device> devices,
    c10::optional<std::vector<int64_t>> chunk_sizes,
    int64_t dim,
    c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>> streams,
    bool unsqueeze_scalars)
    : devices_(std::move(devices)),
      chunk_sizes_(std::move(chunk_sizes)),
      dim_(dim),
      streams_(std::move(streams)),
      unsqueeze_scalars_(unsqueeze_scalars) {}

Scatter::~Scatter() = default;
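// Forward: expects exactly one input. If gradients are required, a Gather
// node targeting the input's original device is registered as the backward
// function before the input is moved into torch::cuda::scatter.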
variable_list Scatter::apply(variable_list&& inputs) {
  AT_ASSERT(inputs.size() == 1);
  auto& input = inputs.front();

  std::shared_ptr<Node> grad_fn;
  if (compute_requires_grad(input)) {
    grad_fn =
        std::make_shared<Gather>(/*destination_device=*/input.device(), dim_);
    grad_fn->set_next_edges(collect_next_edges(input));
  }
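  // Convert the target devices to raw device indices and run the actual
  // scatter across them.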
  auto device_indices = fmap(devices_, [](const at::Device& device) -> int64_t {
    return device.index();
  });
  auto tensors = torch::cuda::scatter(
      std::move(input), device_indices, chunk_sizes_, dim_, streams_);
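  // Wrap each chunk as an output Variable; with unsqueeze_scalars_ set, the
  // 1-element chunks are turned back into 0-dim scalars.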
  std::vector<Variable> variables;
  variables.reserve(tensors.size());
  for (auto& tensor : tensors) {
    AT_ASSERT(tensor.defined());
    if (unsqueeze_scalars_) {
      AT_ASSERT(tensor.dim() == 1 && tensor.numel() == 1);
      variables.push_back(tensor[0]);
    } else {
      variables.push_back(std::move(tensor));
    }
  }

  if (grad_fn) {
    set_history(variables, grad_fn);
  }

  return variables;
}
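// Gather is the inverse of Scatter: it collects CUDA tensors from multiple
// devices onto `destination_device`, concatenating them along `dim`, and
// registers a Scatter back to the source devices as its backward function.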
Gather::Gather(const at::Device& destination_device, int64_t dim)
    : destination_device_(destination_device), dim_(dim) {}

Gather::~Gather() = default;
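// Forward: checks that every input is a CUDA tensor; when all inputs are
// 0-dim scalars they are unsqueezed and the result is returned as a vector.
// If gradients are required, the per-input devices and sizes along `dim` are
// recorded so the backward pass can Scatter the gradient back to its sources.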
variable_list Gather::apply(variable_list&& inputs) {
  bool all_are_zero_dim = true;
  for (const auto& input : inputs) {
    TORCH_CHECK(
        input.is_cuda(),
        "All inputs to Gather must be CUDA tensors, got ",
        input.toString());
    if (input.dim() > 0) {
      all_are_zero_dim = false;
    }
  }

  const bool unsqueeze_scalars = all_are_zero_dim && dim_ == 0;
  if (unsqueeze_scalars) {
    TORCH_WARN(
        "Was asked to gather along dimension 0, but all "
        "input tensors were scalars; will instead unsqueeze "
        "and return a vector.");
  }

  std::shared_ptr<Node> grad_fn;
  // compute this before moving variables from `inputs`
  if (compute_requires_grad(inputs)) {
    std::vector<at::Device> source_devices;
    source_devices.reserve(inputs.size());
    std::vector<int64_t> input_sizes;
    input_sizes.reserve(inputs.size());
    for (auto& input : inputs) {
      source_devices.push_back(input.device());
      input_sizes.push_back(input.size(dim_));
    }
    grad_fn = std::make_shared<Scatter>(
        std::move(source_devices),
        std::move(input_sizes),
        dim_,
        /*streams=*/c10::nullopt,
        /*unsqueeze_scalars=*/unsqueeze_scalars);
    grad_fn->set_next_edges(collect_next_edges(inputs));
  }
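  // Move the inputs into a plain tensor list, viewing 0-dim scalars as 1-d
  // so they can be concatenated along dim 0.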
  std::vector<at::Tensor> tensors;
  tensors.reserve(inputs.size());
  for (auto& variable : inputs) {
    if (unsqueeze_scalars) {
      tensors.push_back(variable.view(1));
    } else {
      tensors.push_back(std::move(variable));
    }
  }

  // Disable autograd during the actual computation.
  // torch::cuda::gather does not return a view or change things in place,
  // so no extra logic is needed here.
  at::Tensor variable;
  {
    at::AutoDispatchBelowAutograd mode;
    // This is special logic for torch::cuda::gather!
    const auto destination_index =
        destination_device_.is_cpu() ? -1 : destination_device_.index();
    variable = torch::cuda::gather(tensors, dim_, destination_index);
  }
  if (grad_fn) {
    set_history(variable, grad_fn);
  }
  return {variable};
}

} // namespace autograd
} // namespace torch