[Distributed] [7/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#124987)

This PR continues the clean-up of clang-tidy warnings in torch/csrc/distributed/c10d, following #124701. In addition, a libfmt dependency is added to the CMake code so that libfmt can be used in the headers. libfmt has to be added as a private dependency of torch_cuda and torch_hip because they include torch/csrc/distributed/c10d/Utils.hpp, which uses libfmt.
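
For context, a minimal sketch of the pattern the header now relies on (an assumed, simplified example, not an excerpt from the diff): error messages in Utils.hpp are built with fmt::format instead of chained std::string concatenation, and because fmt::fmt-header-only is fmt's header-only INTERFACE target, linking it privately only propagates include paths and the FMT_HEADER_ONLY compile definition rather than an extra shared library.

    #include <fmt/format.h>

    #include <stdexcept>
    #include <string>

    // Hypothetical helper mirroring the post-change Utils.hpp style: the
    // message is formatted in one fmt::format call instead of being glued
    // together with repeated operator+ on std::string temporaries.
    inline void checkSameTypeExample(
        const std::string& expected,
        const std::string& actual) {
      if (expected != actual) {
        throw std::invalid_argument(
            fmt::format("mixed types ({} and {})", expected, actual));
      }
    }

Any target that includes Utils.hpp therefore needs fmt's headers on its include path, which is why torch_cuda, torch_hip, and the c10d/RPC test targets link fmt::fmt-header-only in the hunks below.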

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124987
Approved by: https://github.com/malfet
Authored by cyy on 2024-04-27 07:22:27 +00:00; committed by PyTorch MergeBot
parent ce503c1b40
commit b3fd94d15e
16 changed files with 96 additions and 72 deletions

View File

@ -1033,7 +1033,7 @@ elseif(USE_CUDA)
target_compile_definitions(torch_cuda PRIVATE USE_CUSPARSELT)
endif()
if(USE_NCCL)
target_link_libraries(torch_cuda PRIVATE __caffe2_nccl)
target_link_libraries(torch_cuda PRIVATE __caffe2_nccl fmt::fmt-header-only)
target_compile_definitions(torch_cuda PRIVATE USE_NCCL)
endif()
if(USE_UCC)
@ -1776,7 +1776,7 @@ if(USE_ROCM)
target_link_libraries(torch_hip PRIVATE ATEN_CUDA_FILES_GEN_LIB)
endif()
target_link_libraries(torch_hip PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS})
target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS})
target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS} fmt::fmt-header-only)
# Since PyTorch files contain HIP headers, this is also needed to capture the includes.
target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE})

View File

@ -13,6 +13,7 @@ function(c10d_add_test test_src)
if(NOT WIN32)
target_link_libraries(${test_name} pthread)
endif()
target_link_libraries(${test_name} fmt::fmt-header-only)
add_test(NAME ${test_name} COMMAND $<TARGET_FILE:${test_name}>)
endfunction()
@ -92,4 +93,5 @@ if(LINUX AND USE_GLOO AND USE_C10D_GLOO)
if(USE_CUDA)
target_link_libraries(example_allreduce torch_cuda)
endif()
target_link_libraries(example_allreduce fmt::fmt-header-only)
endif()

View File

@ -5,7 +5,7 @@ set(TORCH_RPC_TEST_SOURCES
${TORCH_RPC_TEST_DIR}/test_wire_serialization.cpp
)
set(TORCH_RPC_TEST_DEPENDENCY_LIBS
torch gtest
torch gtest fmt::fmt-header-only
)
if(USE_GLOO)

View File

@ -33,6 +33,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
std::chrono::milliseconds timeout;
// backend name
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::string backend;
};
@ -397,7 +398,9 @@ class TORCH_API Backend : public torch::CustomClassHolder {
// appropriate logging etc.
void init();
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const int rank_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const int size_;
// Debug level setting. It is parsed once when ProcessGroup is constructed and
// remains the same across use of this process group.

View File

@ -59,10 +59,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
std::chrono::milliseconds timeout;
// backend name
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::string backend;
};
enum BackendType {
enum BackendType : uint8_t {
UNDEFINED = 0,
GLOO = 1,
NCCL = 2,
@ -719,9 +720,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
void init();
c10::intrusive_ptr<c10d::Store> store_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const int rank_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const int size_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const c10::intrusive_ptr<Options> options_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const BackendType backendType_;
std::string pg_desc_;

View File

@ -975,7 +975,8 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::broadcast(
};
assertRootRank(invalidArgument, opts.rootRank, size_);
assertRootTensor(invalidArgument, opts.rootTensor, inputs.size());
assertRootTensor(
invalidArgument, opts.rootTensor, static_cast<int64_t>(inputs.size()));
assertDense(invalidArgument, inputs);
assertTypeAndSizesMatch(invalidArgument, inputs);
@ -1300,7 +1301,9 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
// Allgatherv indices.
gloo::AllgathervOptions opts(context);
opts.setInput(
const_cast<int64_t*>(input.const_data_ptr<int64_t>()), input.numel());
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<int64_t*>(input.const_data_ptr<int64_t>()),
input.numel());
opts.setOutput(output.mutable_data_ptr<int64_t>(), counts);
opts.setTag(tag);
gloo::allgatherv(opts);
@ -1308,7 +1311,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
// Compile indices tensor per rank.
std::vector<at::Tensor> indices;
indices.reserve(metadata.size());
size_t offset = 0;
int64_t offset = 0;
for (const auto& i : metadata) {
const auto nnz = i.nnz();
const auto numel = sparseDim * nnz;
@ -1325,7 +1328,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
const std::vector<SparseTensorMetadata>& metadata) {
// There are nnz #dense_dim()-dimensional tensors per rank.
const auto valueShape = tensor.sizes().slice(tensor.sparse_dim());
size_t denseNumel = 1;
int64_t denseNumel = 1;
for (auto dim : valueShape) {
denseNumel *= dim;
}
@ -1334,7 +1337,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
int64_t totalSize = 0;
for (const auto i : c10::irange(metadata.size())) {
counts[i] = metadata[i].nnz() * denseNumel;
totalSize += counts[i];
totalSize += static_cast<int64_t>(counts[i]);
}
auto output = at::empty({totalSize}, tensor.scalar_type());
@ -1353,7 +1356,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
// Compile values tensor per rank.
std::vector<at::Tensor> values;
values.reserve(metadata.size());
size_t offset = 0;
int64_t offset = 0;
for (const auto& i : metadata) {
const auto nnz = i.nnz();
const auto numel = denseNumel * nnz;
@ -1740,7 +1743,8 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::reduce(
};
assertRootRank(invalidArgument, opts.rootRank, size_);
assertRootTensor(invalidArgument, opts.rootTensor, inputs.size());
assertRootTensor(
invalidArgument, opts.rootTensor, static_cast<int64_t>(inputs.size()));
assertSingleElement(invalidArgument, inputs);
assertDense(invalidArgument, inputs);
@ -1832,7 +1836,7 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
// Unflatten into output tensors.
for (auto& outputgroup : outputs) {
for (const auto j : c10::irange(outputgroup.size())) {
outputgroup[j].copy_(flatOutputTensor[j]);
outputgroup[j].copy_(flatOutputTensor[static_cast<int64_t>(j)]);
}
}
}
@ -2102,7 +2106,7 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork {
for (const auto& t : output_lists[0]) {
output_numel += t.numel();
}
output_numel *= output_lists.size();
output_numel *= static_cast<int64_t>(output_lists.size());
// Use single flat output tensor.
at::Tensor flatOutputTensor =
at::empty({output_numel}, output_lists[0][0].options());
@ -2251,7 +2255,7 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
// Unflatten into output tensors on root process.
if (context->rank == root) {
for (const auto i : c10::irange(outputs[0].size())) {
outputs[0][i].copy_(flatOutputTensor[i]);
outputs[0][i].copy_(flatOutputTensor[static_cast<int64_t>(i)]);
}
}
}
@ -2805,6 +2809,7 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::send(
// Construct unbound buffer.
auto context = getContext(tag);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto buf = context->createUnboundBuffer(const_cast<void*>(ptr), size);
buf->send(dstRank, utag);
++seq_;
@ -2945,8 +2950,8 @@ void ProcessGroupGloo::monitoredBarrier(
// only enforce timeout on rank 0. This is so that other ranks aren't timed
// out first, bringing down the job without reporting which rank timed out.
if (rank != 0) {
auto sendWork = send(commTensor, 0, t1);
auto recvWork = recv(commTensor, 0, t2);
auto sendWork = send(commTensor, 0, static_cast<int>(t1));
auto recvWork = recv(commTensor, 0, static_cast<int>(t2));
try {
sendWork->wait();
recvWork->wait();
@ -2970,7 +2975,8 @@ void ProcessGroupGloo::monitoredBarrier(
// Failed/hanging ranks will not ack this call, letting rank 0 know about the
// failure.
for (const auto dstRank : c10::irange(1, worldSize)) {
recvWorkMap.insert({dstRank, recv(commTensor, dstRank, t1)});
recvWorkMap.emplace(
dstRank, recv(commTensor, dstRank, static_cast<int>(t1)));
}
auto waitLoop = [&](const std::map<int, c10::intrusive_ptr<Work>>& works) {
@ -3042,7 +3048,8 @@ void ProcessGroupGloo::monitoredBarrier(
// ensures that this is a true barrier in that all ranks exit it successfully
// or none of them do.
for (const auto dstRank : c10::irange(1, worldSize)) {
sendWorkMap.insert({dstRank, send(commTensor, dstRank, t2)});
sendWorkMap.emplace(
dstRank, send(commTensor, dstRank, static_cast<int>(t2)));
}
waitLoop(sendWorkMap);

View File

@ -514,7 +514,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::allgather(
pgComm_));
for (const auto i : c10::irange(outputDataVec.size())) {
outputDataVec[i].copy_(flatOutputTensor[i]);
outputDataVec[i].copy_(flatOutputTensor[static_cast<int64_t>(i)]);
}
};
auto entry = std::make_unique<WorkEntry>(
@ -586,7 +586,8 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::gather(
const std::vector<at::Tensor>& outputDataVec = entry->dst;
// copy the flattened output tensors to the outputs
for (const auto i : c10::irange(outputDataVec.size())) {
outputDataVec.at(i).copy_(flatOutputTensor[i]);
outputDataVec.at(i).copy_(
flatOutputTensor[static_cast<int64_t>(i)]);
}
}
};
@ -647,7 +648,7 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::scatter(
// copy the input tensors to the flatten large send buffer
for (const auto i : c10::irange(inputDataVec.size())) {
flatInputTensor[i].copy_(inputDataVec.at(i));
flatInputTensor[static_cast<int64_t>(i)].copy_(inputDataVec.at(i));
}
}
@ -793,16 +794,18 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall(
std::vector<int> recv_offsets(size_);
auto srcdata = entry->src;
auto dstdata = entry->dst;
int64_t src_len = c10d::computeLengthsAndOffsets(
auto src_len = c10d::computeLengthsAndOffsets(
srcdata, &send_lengths, &send_offsets);
int64_t dst_len = c10d::computeLengthsAndOffsets(
auto dst_len = c10d::computeLengthsAndOffsets(
dstdata, &recv_lengths, &recv_offsets);
std::vector<int64_t> send_lengthsL(
send_lengths.begin(), send_lengths.end());
std::vector<int64_t> recv_lengthsL(
recv_lengths.begin(), recv_lengths.end());
at::Tensor srcFlatData = at::empty({src_len}, srcdata[0].options());
at::Tensor dstFlatData = at::empty({dst_len}, dstdata[0].options());
at::Tensor srcFlatData =
at::empty({static_cast<int64_t>(src_len)}, srcdata[0].options());
at::Tensor dstFlatData =
at::empty({static_cast<int64_t>(dst_len)}, dstdata[0].options());
auto srcFlatDataSplits =
srcFlatData.split_with_sizes(c10::IntArrayRef(send_lengthsL), 0);
for (const auto i : c10::irange(size_)) {

View File

@ -31,12 +31,12 @@ struct CollectiveFingerPrint {
std::vector<int8_t> tensor_device_types_;
// input tensor sizes
std::vector<std::vector<int64_t>> tensor_sizes_;
int sequence_number_;
uint64_t sequence_number_;
CollectiveFingerPrint(
OpType op_type,
const std::vector<at::Tensor>& input_tensors,
int sequence_number)
uint64_t sequence_number)
: op_type_(op_type),
num_tensors_(input_tensors.size()),
sequence_number_(sequence_number) {
@ -57,7 +57,7 @@ struct CollectiveFingerPrint {
std::vector<int8_t> tensor_dtypes,
std::vector<int8_t> tensor_device_types,
std::vector<std::vector<int64_t>> tensor_sizes,
int sequence_number)
uint64_t sequence_number)
: op_type_(op_type),
num_tensors_(num_tensors),
tensor_dtypes_(std::move(tensor_dtypes)),
@ -296,7 +296,7 @@ struct CollectiveFingerPrint {
// 1. OpType
data->push_back(static_cast<int64_t>(op_type_));
// sequence number
data->push_back(sequence_number_);
data->push_back(static_cast<int64_t>(sequence_number_));
// 2. Num tensors
data->push_back(static_cast<int64_t>(num_tensors_));
// 3. Tensor dtypes
@ -309,13 +309,13 @@ struct CollectiveFingerPrint {
}
// 5. Shapes
for (const auto& sizes : tensor_sizes_) {
data->push_back(sizes.size());
data->push_back(static_cast<int64_t>(sizes.size()));
for (const auto& s : sizes) {
data->push_back(s);
}
}
// Serialize data into tensor
int64_t data_size = data->size();
int64_t data_size = static_cast<int64_t>(data->size());
// Need to release here and get the ptr due to C++ parameter evaluation
// order.
auto d = data.release();

View File

@ -207,7 +207,7 @@ class TORCH_PYTHON_API PythonOnCompletionHook {
hook_.ptr() = nullptr;
}
void operator()(std::shared_ptr<WorkInfo> workInfo) const {
void operator()(const std::shared_ptr<WorkInfo>& workInfo) const {
std::exception_ptr eptr;
{
py::gil_scoped_acquire acquire;

View File

@ -30,10 +30,10 @@ class Counter {
return count_;
}
double variance() const noexcept {
return m2_ / count_;
return m2_ / static_cast<double>(count_);
}
double sample_variance() const noexcept {
return m2_ / (count_ - 1);
return m2_ / static_cast<double>(count_ - 1);
}
private:

View File

@ -1,10 +1,6 @@
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <thread>
namespace c10d {

View File

@ -4,6 +4,7 @@
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/c10d/Types.hpp>
#ifdef _WIN32
@ -66,7 +67,7 @@ inline void assertSameType(
const std::string expected = type.toString();
const std::string actual = tensors[i].toString();
throw std::invalid_argument(
"mixed types (" + expected + " and " + actual + ")");
fmt::format("mixed types ({} and {})", expected, actual));
}
}
}
@ -96,7 +97,7 @@ inline std::string getCvarString(
/* parse environment variable in reverse order, so the early
* versions of a variable get higher priority than the latter
* versions of the same variable */
for (int i = env.size() - 1; i >= 0; i--) {
for (ssize_t i = static_cast<ssize_t>(env.size()) - 1; i >= 0; i--) {
const char* val = std::getenv(env[i].c_str());
if (val == nullptr) {
continue;
@ -123,7 +124,7 @@ inline int getCvarInt(const std::vector<std::string>& env, int def) {
/* parse environment variable in reverse order, so the early
* versions of a variable get higher priority than the latter
* versions of the same variable */
for (int i = env.size() - 1; i >= 0; i--) {
for (ssize_t i = static_cast<ssize_t>(env.size()) - 1; i >= 0; i--) {
char* val = std::getenv(env[i].c_str());
if (val == nullptr) {
continue;
@ -154,7 +155,7 @@ inline bool getCvarBool(const std::vector<std::string>& env, bool def) {
/* parse environment variable in reverse order, so the early
* versions of a variable get higher priority than the latter
* versions of the same variable */
for (int i = env.size() - 1; i >= 0; i--) {
for (ssize_t i = static_cast<ssize_t>(env.size()) - 1; i >= 0; i--) {
char* val_ = std::getenv(env[i].c_str());
if (val_ == nullptr) {
continue;
@ -166,6 +167,7 @@ inline bool getCvarBool(const std::vector<std::string>& env, bool def) {
std::string val = std::string(val_);
for (auto& x : val) {
// NOLINTNEXTLINE(*-narrowing-conversions)
x = std::tolower(x);
}
@ -193,7 +195,7 @@ inline void assertSameSizes(
const auto expected = toString(sizes);
const auto actual = toString(tensors[i].sizes());
throw std::invalid_argument(
"mixed sizes (" + expected + " and " + actual + ")");
fmt::format("mixed sizes ({} and {})", expected, actual));
}
}
}
@ -211,22 +213,20 @@ inline void assertSameSizeAndType(const std::vector<at::Tensor>& tensors) {
if (!tensors[i].options().type_equal(options)) {
const auto expected = toString(options);
const auto actual = toString(tensors[i].options());
throw std::invalid_argument(
"argument contains mixed types (" + expected + " and " + actual +
")");
throw std::invalid_argument(fmt::format(
"argument contains mixed types ({} and {})", expected, actual));
}
if (!tensors[i].sizes().equals(sizes)) {
const auto expected = toString(sizes);
const auto actual = toString(tensors[i].sizes());
throw std::invalid_argument(
"argument contains mixed sizes (" + expected + " and " + actual +
")");
throw std::invalid_argument(fmt::format(
"argument contains mixed types ({} and {})", expected, actual));
}
}
}
inline void assertTypeMatch(
std::function<void(const std::string&)> fn,
const std::function<void(const std::string&)>& fn,
const at::DeprecatedTypeProperties& type,
const at::ArrayRef<at::Tensor> tensors,
size_t index) {
@ -237,7 +237,7 @@ inline void assertTypeMatch(
}
inline void assertTypeMatch(
std::function<void(const std::string&)> fn,
const std::function<void(const std::string&)>& fn,
const at::TensorOptions& options,
const at::ArrayRef<at::Tensor> tensors,
size_t index) {
@ -248,7 +248,7 @@ inline void assertTypeMatch(
}
inline void assertSizesMatch(
std::function<void(const std::string&)> fn,
const std::function<void(const std::string&)>& fn,
const at::IntArrayRef& sizes,
const at::ArrayRef<at::Tensor> tensors,
size_t index) {
@ -259,7 +259,7 @@ inline void assertSizesMatch(
}
inline void assertLayoutMatch(
std::function<void(const std::string&)> fn,
const std::function<void(const std::string&)>& fn,
const c10::Layout& expected,
const at::ArrayRef<at::Tensor> tensors,
size_t index) {
@ -271,7 +271,7 @@ inline void assertLayoutMatch(
}
inline void assertLayoutMatch(
std::function<void(const std::string&)> fn,
const std::function<void(const std::string&)>& fn,
const at::ArrayRef<at::Tensor> tensors) {
const auto& layout = tensors[0].layout();
for (const auto i : c10::irange(1, tensors.size())) {
@ -362,7 +362,7 @@ inline void assertSameDevice(
}
inline void assertTypeAndSizesMatch(
std::function<void(const std::string&)> fn,
const std::function<void(const std::string&)>& fn,
const at::ArrayRef<at::Tensor> tensors,
const at::DeprecatedTypeProperties& type,
const at::IntArrayRef& sizes) {
@ -373,7 +373,7 @@ inline void assertTypeAndSizesMatch(
}
inline void assertTypeAndSizesMatch(
std::function<void(const std::string&)> fn,
const std::function<void(const std::string&)>& fn,
const at::ArrayRef<at::Tensor> tensors,
const at::TensorOptions& options,
const at::IntArrayRef& sizes) {
@ -384,7 +384,7 @@ inline void assertTypeAndSizesMatch(
}
inline void assertTypeAndSizesMatch(
std::function<void(const std::string&)> fn,
const std::function<void(const std::string&)>& fn,
const at::ArrayRef<at::Tensor> tensors) {
const auto& options = tensors[0].options();
const auto sizes = tensors[0].sizes();
@ -463,6 +463,7 @@ inline std::vector<int> getDevices(const std::vector<at::Tensor>& tensors) {
std::vector<int> devices(tensors.size(), -1);
if (tensors[0].device().is_cuda()) {
for (const auto i : c10::irange(tensors.size())) {
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
devices[i] = tensors[i].storage().device().index();
}
}
@ -620,8 +621,7 @@ void sendBytes(
return;
}
auto bytes = reinterpret_cast<const uint8_t*>(buffer);
uint8_t* currentBytes = const_cast<uint8_t*>(bytes);
auto currentBytes = reinterpret_cast<const char*>(buffer);
int flags = 0;
@ -637,10 +637,9 @@ void sendBytes(
#endif
while (bytesToSend > 0) {
ssize_t bytesSent;
ssize_t bytesSent = 0;
SYSCHECK_ERR_RETURN_NEG1(
bytesSent =
::send(socket, (const char*)currentBytes, bytesToSend, flags))
bytesSent = ::send(socket, currentBytes, bytesToSend, flags))
if (bytesSent == 0) {
C10_THROW_ERROR(DistNetworkError, std::strerror(ECONNRESET));
}
@ -657,13 +656,12 @@ void recvBytes(int socket, T* buffer, size_t length) {
return;
}
auto bytes = reinterpret_cast<uint8_t*>(buffer);
uint8_t* currentBytes = bytes;
auto currentBytes = reinterpret_cast<char*>(buffer);
while (bytesToReceive > 0) {
ssize_t bytesReceived;
ssize_t bytesReceived = 0;
SYSCHECK_ERR_RETURN_NEG1(
bytesReceived = recv(socket, (char*)currentBytes, bytesToReceive, 0))
bytesReceived = recv(socket, currentBytes, bytesToReceive, 0))
if (bytesReceived == 0) {
C10_THROW_ERROR(DistNetworkError, std::strerror(ECONNRESET));
}
@ -684,7 +682,7 @@ void sendVector(int socket, const std::vector<T>& vec, bool moreData = false) {
// receive a vector as sent in sendVector
template <typename T>
std::vector<T> recvVector(int socket) {
SizeType valueSize;
SizeType valueSize = 0;
recvBytes<SizeType>(socket, &valueSize, 1);
std::vector<T> value(valueSize);
recvBytes<T>(socket, value.data(), value.size());
@ -716,7 +714,7 @@ inline void sendString(
// receive a string as sent in sendString
inline std::string recvString(int socket) {
SizeType valueSize;
SizeType valueSize = 0;
recvBytes<SizeType>(socket, &valueSize, 1);
std::vector<char> value(valueSize);
recvBytes<char>(socket, value.data(), value.size());

View File

@ -87,6 +87,7 @@ class TORCH_API GradBucket {
std::vector<c10::IntArrayRef> sizes_vec_;
// Model parameters for this bucket.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::vector<at::Tensor> parameters_;
// Predefined sparse indices for this bucket (only used for sparse tensors).

View File

@ -5,7 +5,7 @@
namespace c10d {
enum class BuiltinCommHookType {
enum class BuiltinCommHookType : uint8_t {
ALLREDUCE = 1,
FP16_COMPRESS = 2,
};

View File

@ -14,9 +14,18 @@ constexpr size_t kDefaultBufferSize = 10ull * 1024 * 1024;
using NvlMesh = std::array<std::array<size_t, kMaxDevices>, kMaxDevices>;
using HybridCubeMesh = std::array<std::array<int, 4>, kMaxDevices>;
enum class Topology { UNKNOWN = 0, FULLY_CONNECTED = 1, HYBRID_CUBE_MESH = 2 };
enum class Topology : uint8_t {
UNKNOWN = 0,
FULLY_CONNECTED = 1,
HYBRID_CUBE_MESH = 2
};
enum class AllReduceAlgo { NONE = 0, ONE_SHOT = 1, TWO_SHOT = 2, HCM = 3 };
enum class AllReduceAlgo : uint8_t {
NONE = 0,
ONE_SHOT = 1,
TWO_SHOT = 2,
HCM = 3
};
class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
public:

View File

@ -23,7 +23,7 @@ class TORCH_API Timer {
int64_t backward_comm_end_time = kUnsetTime;
public:
enum class Event {
enum class Event : uint8_t {
kForwardStart,
kBackwardComputeStart,
kBackwardComputeEnd,