mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Refactor the code that gets the start and stride from global ranks so that this function can be used by different collective backends. Differential Revision: D69555405 Pull Request resolved: https://github.com/pytorch/pytorch/pull/147230 Approved by: https://github.com/kwen2501
65 lines
1.7 KiB
C++
65 lines
1.7 KiB
C++
#include <torch/csrc/distributed/c10d/Utils.hpp>
|
|
|
|
#include <cstring>
|
|
|
|
namespace c10d {
|
|
|
|
std::vector<at::Tensor> getTensorShapes(
|
|
const std::vector<at::Tensor>& tensors) {
|
|
std::vector<at::Tensor> shapeTensors;
|
|
shapeTensors.reserve(tensors.size());
|
|
for (const auto& tensor : tensors) {
|
|
// Use `at::tensor()` to copy the data underlying `sizes()` since it may be
|
|
// released elsewhere.
|
|
at::Tensor shapesTensor =
|
|
at::tensor(tensor.sizes(), at::TensorOptions().dtype(at::kLong));
|
|
shapeTensors.emplace_back(std::move(shapesTensor));
|
|
}
|
|
return shapeTensors;
|
|
}
|
|
|
|
// Returns the total number of elements across all `tensors`
// (0 for an empty list).
size_t getTensorsNumel(const std::vector<at::Tensor>& tensors) {
  size_t total = 0;
  for (size_t idx = 0; idx < tensors.size(); ++idx) {
    // numel() is a signed count; make the conversion to size_t explicit.
    total += static_cast<size_t>(tensors[idx].numel());
  }
  return total;
}
|
|
|
|
// Summarizes a list of global ranks as an arithmetic progression
// (start + i * stride).
//
// Outputs:
//   globalRankStart  - the first rank in the list, or 0 if the list is empty.
//   globalRankStride - 1 if the list is empty, 0 if it holds a single rank,
//                      the common step if the ranks form a non-decreasing
//                      arithmetic progression, and -1 otherwise
//                      ("not strided").
//
// Descending lists are reported as -1 (not strided). The previous
// implementation computed the step in uint64_t, let it wrap, and narrowed it
// to a negative int. For {3, 2, 1} that produced -1, which is
// indistinguishable from the "not strided" sentinel.
//
// NOTE(review): ranks and the step are narrowed to int; values above INT_MAX
// are assumed not to occur.
void getGlobalRankStartAndStride(
    const std::vector<uint64_t>& globalRanksInGroup,
    int& globalRankStart,
    int& globalRankStride) {
  // An empty list is reported as start 0, stride 1 (presumably "all ranks";
  // confirm with callers).
  if (globalRanksInGroup.empty()) {
    globalRankStart = 0;
    globalRankStride = 1;
    return;
  }

  globalRankStart = static_cast<int>(globalRanksInGroup[0]);

  if (globalRanksInGroup.size() == 1) {
    globalRankStride = 0;
    return;
  }

  const uint64_t first = globalRanksInGroup[0];
  const uint64_t second = globalRanksInGroup[1];
  if (second < first) {
    // A negative step cannot be reported unambiguously; treat as not strided.
    globalRankStride = -1;
    return;
  }

  const uint64_t stride = second - first;
  // Compare adjacent differences rather than first + i * stride so the check
  // cannot overflow and never relies on unsigned wraparound.
  for (size_t i = 2; i < globalRanksInGroup.size(); ++i) {
    const uint64_t prev = globalRanksInGroup[i - 1];
    const uint64_t cur = globalRanksInGroup[i];
    if (cur < prev || cur - prev != stride) {
      globalRankStride = -1;
      return;
    }
  }

  globalRankStride = static_cast<int>(stride);
}
|
|
|
|
} // namespace c10d
|