mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This PR continues to clean clang-tidy warnings in torch/csrc/distributed/c10d, following #124701. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124987 Approved by: https://github.com/malfet
30 lines
781 B
C++
30 lines
781 B
C++
#include <torch/csrc/distributed/c10d/Utils.hpp>
|
|
|
|
#include <cstring>
|
|
|
|
namespace c10d {
|
|
|
|
std::vector<at::Tensor> getTensorShapes(
|
|
const std::vector<at::Tensor>& tensors) {
|
|
std::vector<at::Tensor> shapeTensors;
|
|
shapeTensors.reserve(tensors.size());
|
|
for (const auto& tensor : tensors) {
|
|
// Use `at::tensor()` to copy the data underlying `sizes()` since it may be
|
|
// released elsewhere.
|
|
at::Tensor shapesTensor =
|
|
at::tensor(tensor.sizes(), at::TensorOptions().dtype(at::kLong));
|
|
shapeTensors.emplace_back(std::move(shapesTensor));
|
|
}
|
|
return shapeTensors;
|
|
}
|
|
|
|
size_t getTensorsNumel(const std::vector<at::Tensor>& tensors) {
|
|
size_t numel = 0;
|
|
for (auto& tensor : tensors) {
|
|
numel += tensor.numel();
|
|
}
|
|
return numel;
|
|
}
|
|
|
|
} // namespace c10d
|