mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59563 ghstack-source-id: 131331264 Test Plan: CI Reviewed By: malfet Differential Revision: D28932239 fbshipit-source-id: 5df6cdfa5253b15cbbc97039fe672d6d97321e34
81 lines
1.7 KiB
C++
81 lines
1.7 KiB
C++
#pragma once
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
#include <c10/macros/Macros.h>
|
|
#include <c10/util/ThreadLocalDebugInfo.h>
|
|
#include <ATen/core/ivalue.h>
|
|
|
|
namespace torch {
|
|
|
|
// Name string that RECORD_PARAM_COMMS passes to RECORD_FUNCTION.
// Only declared here (`extern`); the definition lives in another translation unit.
extern TORCH_API const std::string kParamCommsCallName;
|
|
|
|
class TORCH_API ParamCommsDebugInfo
|
|
: public c10::DebugInfoBase {
|
|
|
|
public:
|
|
ParamCommsDebugInfo() = default;
|
|
ParamCommsDebugInfo(
|
|
int rank,
|
|
std::string&& colName,
|
|
int inSize,
|
|
int outSize,
|
|
at::ScalarType dType,
|
|
std::vector<int64_t> inSplitSizes,
|
|
std::vector<int64_t> outSplitSizes);
|
|
|
|
~ParamCommsDebugInfo() override = default;
|
|
|
|
int getRank() const {
|
|
return rank_;
|
|
}
|
|
|
|
const std::string getColumnName() const {
|
|
return columnName_;
|
|
}
|
|
|
|
int getInMessageSize() const {
|
|
return inMessageSize_;
|
|
}
|
|
|
|
int getOutMessageSize() const {
|
|
return outMessageSize_;
|
|
}
|
|
|
|
at::ScalarType getDType() const {
|
|
return dType_;
|
|
}
|
|
|
|
const std::vector<int64_t>& getInputSplitSizes() const {
|
|
return inputSplitSizes_;
|
|
}
|
|
|
|
const std::vector<int64_t>& getOutputSplitSizes() const {
|
|
return outputSplitSizes_;
|
|
}
|
|
|
|
private:
|
|
int rank_{};
|
|
std::string columnName_;
|
|
int inMessageSize_{};
|
|
int outMessageSize_{};
|
|
at::ScalarType dType_ = at::kByte;
|
|
std::vector<int64_t> inputSplitSizes_;
|
|
std::vector<int64_t> outputSplitSizes_;
|
|
};
|
|
|
|
|
|
// Expands to three statements in the caller's scope:
//   1. creates a shared torch::ParamCommsDebugInfo from the macro arguments,
//   2. declares a c10::DebugInfoGuard named `g` that registers it under
//      c10::DebugInfoKind::PARAM_COMMS_INFO,
//   3. invokes RECORD_FUNCTION with torch::kParamCommsCallName and an empty
//      list of inputs.
// Because it declares the locals `paramCommsInfo` and `g`, a second expansion
// in the same scope would redeclare them.
#define RECORD_PARAM_COMMS(rank, colName, inSize, outSize, dType, inSplitSizes, outSplitSizes) \
  auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>(                          \
      rank, colName, inSize, outSize, dType, inSplitSizes, outSplitSizes);                     \
  c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo);                 \
  RECORD_FUNCTION(torch::kParamCommsCallName, std::vector<c10::IValue>());
|
|
|
|
} // namespace torch
|