#include <torch/csrc/autograd/functions/comm.h>

#include <ATen/core/functional.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/cuda/comm.h>

#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Optional.h>

#include <cstddef>
#include <memory>
#include <vector>

namespace torch {
namespace autograd {

Scatter::Scatter(
    std::vector<at::Device> devices,
    const c10::optional<std::vector<int64_t>>& chunk_sizes,
    int64_t dim,
    const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>&
        streams,
    bool unsqueeze_scalars)
    : devices_(std::move(devices)),
      chunk_sizes_(chunk_sizes),
      dim_(dim),
      streams_(streams),
      unsqueeze_scalars_(unsqueeze_scalars) {}

Scatter::~Scatter() {}

variable_list Scatter::apply(variable_list&& inputs) {
  AT_ASSERT(inputs.size() == 1);
  auto& input = inputs.front();

  // Scatter is differentiable: its backward is a Gather of the gradient
  // chunks back onto the source device, along the same dimension.
  std::shared_ptr<Node> grad_fn;
  if (compute_requires_grad(input)) {
    grad_fn =
        std::make_shared<Gather>(/*destination_device=*/input.device(), dim_);
    grad_fn->set_next_edges(collect_next_edges(input));
  }

  auto device_indices = fmap(devices_, [](const at::Device& device) -> int64_t {
    return device.index();
  });
  auto tensors = torch::cuda::scatter(
      std::move(input), device_indices, chunk_sizes_, dim_, streams_);

  std::vector<Variable> variables;
  variables.reserve(tensors.size());
  for (auto& tensor : tensors) {
    AT_ASSERT(tensor.defined());
    if (unsqueeze_scalars_) {
      AT_ASSERT(tensor.dim() == 1 && tensor.numel() == 1);
      variables.push_back(tensor[0]);
    } else {
      variables.push_back(std::move(tensor));
    }
  }

  if (grad_fn) {
    set_history(variables, grad_fn);
  }

  return variables;
}

Gather::Gather(const at::Device& destination_device, int64_t dim)
    : destination_device_(destination_device), dim_(dim) {}

Gather::~Gather() {}

variable_list Gather::apply(variable_list&& inputs) {
  bool all_are_zero_dim = true;
  for (const auto& input : inputs) {
    TORCH_CHECK(
        input.is_cuda(),
        "All inputs to Gather must be CUDA tensors, got ",
        input.toString());
    if (input.dim() > 0) {
      all_are_zero_dim = false;
    }
  }

  const bool unsqueeze_scalars = all_are_zero_dim && dim_ == 0;
  if (unsqueeze_scalars) {
    AT_WARN(
        "Was asked to gather along dimension 0, but all "
        "input tensors were scalars; will instead unsqueeze "
        "and return a vector.");
  }

  // Gather is differentiable: its backward is a Scatter of the gradient back
  // to the source devices, using the per-input sizes recorded here. Build the
  // backward node before the loop below moves the variables out of `inputs`.
  std::shared_ptr<Node> grad_fn;
  if (compute_requires_grad(inputs)) {
    std::vector<at::Device> source_devices;
    std::vector<int64_t> input_sizes;
    for (auto& input : inputs) {
      source_devices.push_back(input.device());
      input_sizes.push_back(input.size(dim_));
    }
    grad_fn = std::make_shared<Scatter>(
        std::move(source_devices),
        std::move(input_sizes),
        dim_,
        /*streams=*/c10::nullopt,
        /*unsqueeze_scalars=*/unsqueeze_scalars);
    grad_fn->set_next_edges(collect_next_edges(inputs));
  }

  std::vector<at::Tensor> tensors;
  tensors.reserve(inputs.size());
  for (auto& variable : inputs) {
    if (unsqueeze_scalars) {
      tensors.push_back(variable.view(1));
    } else {
      tensors.push_back(std::move(variable));
    }
  }

  // This is special logic for torch::cuda::gather! A destination index of -1
  // tells it to gather onto the CPU rather than a CUDA device.
  const auto destination_index =
      destination_device_.is_cpu() ? -1 : destination_device_.index();
  auto variable = torch::cuda::gather(tensors, dim_, destination_index);
  if (grad_fn) {
    set_history(variable, grad_fn);
  }
  return {variable};
}

} // namespace autograd
} // namespace torch
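
// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the upstream file: a minimal round trip
// that drives the Scatter and Gather nodes above directly, in the style of
// the C++ data-parallel path. The function name, device indices, and tensor
// shape are assumptions for illustration only; Scatter, Gather, and their
// apply() overloads are the ones defined in this file.
namespace torch {
namespace autograd {

void scatter_gather_roundtrip_example() {
  // Source tensor on cuda:0, to be split along dimension 0.
  auto input = at::randn(
      {8, 4}, at::TensorOptions().device(at::Device(at::kCUDA, 0)));

  // Split into equal chunks across cuda:0 and cuda:1 on the default streams.
  Scatter scatter(
      {at::Device(at::kCUDA, 0), at::Device(at::kCUDA, 1)},
      /*chunk_sizes=*/c10::nullopt,
      /*dim=*/0,
      /*streams=*/c10::nullopt,
      /*unsqueeze_scalars=*/false);
  auto chunks = scatter.apply({std::move(input)});

  // Reassemble the chunks on cuda:0 along the same dimension. If the inputs
  // required grad, Gather::apply would also wire up a Scatter backward node.
  Gather gather(at::Device(at::kCUDA, 0), /*dim=*/0);
  auto gathered = gather.apply(std::move(chunks));
  AT_ASSERT(gathered.front().device() == at::Device(at::kCUDA, 0));
}

} // namespace autograd
} // namespace torch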