[Distributed/Profiler] Fix input/output dimension overflow (#134360)

Summary: When using ParamCommsDebugInfo, the input elements and output elements are stored in `int` instead of `int64_t`

Test Plan: Run HTA with new outputted values and make sure overflow does not occur

Reviewed By: fengxizhou

Differential Revision: D61728747

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134360
Approved by: https://github.com/fengxizhou, https://github.com/jeanschmidt
This commit is contained in:
Shivam Raikundalia 2024-08-25 16:25:56 +00:00 committed by PyTorch MergeBot
parent e93ca12c88
commit 816061843a
2 changed files with 8 additions and 8 deletions

View File

@ -11,8 +11,8 @@ ParamCommsDebugInfo::ParamCommsDebugInfo(
std::tuple<std::string, std::string> pgName,
int rank,
std::string&& collName,
int inNelems,
int outNelems,
int64_t inNelems,
int64_t outNelems,
at::ScalarType dType,
std::vector<int64_t> inSplitSizes,
std::vector<int64_t> outSplitSizes,

View File

@ -16,8 +16,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
std::tuple<std::string, std::string> pgName,
int rank,
std::string&& collName,
int inNelems,
int outNelems,
int64_t inNelems,
int64_t outNelems,
at::ScalarType dType,
std::vector<int64_t> inSplitSizes,
std::vector<int64_t> outSplitSizes,
@ -55,11 +55,11 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
return collectiveName_;
}
int getInMessageNelems() const {
int64_t getInMessageNelems() const {
return inMessageNelems_;
}
int getOutMessageNelems() const {
int64_t getOutMessageNelems() const {
return outMessageNelems_;
}
@ -84,8 +84,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
int rank_{};
int worldSize_{};
std::string collectiveName_;
int inMessageNelems_{};
int outMessageNelems_{};
int64_t inMessageNelems_{};
int64_t outMessageNelems_{};
at::ScalarType dType_ = at::kByte;
std::vector<int64_t> inputSplitSizes_;
std::vector<int64_t> outputSplitSizes_;