mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Distributed/Profiler] Fix input/output dimension overflow (#134360)
Summary: When using ParamCommsDebugInfo, the input elements and output elements are stored in `int` instead of `int64_t` Test Plan: Run HTA with new outputted values and make sure overflow does not occur Reviewed By: fengxizhou Differential Revision: D61728747 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134360 Approved by: https://github.com/fengxizhou, https://github.com/jeanschmidt
This commit is contained in:
parent
e93ca12c88
commit
816061843a
|
|
@ -11,8 +11,8 @@ ParamCommsDebugInfo::ParamCommsDebugInfo(
|
|||
std::tuple<std::string, std::string> pgName,
|
||||
int rank,
|
||||
std::string&& collName,
|
||||
int inNelems,
|
||||
int outNelems,
|
||||
int64_t inNelems,
|
||||
int64_t outNelems,
|
||||
at::ScalarType dType,
|
||||
std::vector<int64_t> inSplitSizes,
|
||||
std::vector<int64_t> outSplitSizes,
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
|
|||
std::tuple<std::string, std::string> pgName,
|
||||
int rank,
|
||||
std::string&& collName,
|
||||
int inNelems,
|
||||
int outNelems,
|
||||
int64_t inNelems,
|
||||
int64_t outNelems,
|
||||
at::ScalarType dType,
|
||||
std::vector<int64_t> inSplitSizes,
|
||||
std::vector<int64_t> outSplitSizes,
|
||||
|
|
@ -55,11 +55,11 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
|
|||
return collectiveName_;
|
||||
}
|
||||
|
||||
int getInMessageNelems() const {
|
||||
int64_t getInMessageNelems() const {
|
||||
return inMessageNelems_;
|
||||
}
|
||||
|
||||
int getOutMessageNelems() const {
|
||||
int64_t getOutMessageNelems() const {
|
||||
return outMessageNelems_;
|
||||
}
|
||||
|
||||
|
|
@ -84,8 +84,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
|
|||
int rank_{};
|
||||
int worldSize_{};
|
||||
std::string collectiveName_;
|
||||
int inMessageNelems_{};
|
||||
int outMessageNelems_{};
|
||||
int64_t inMessageNelems_{};
|
||||
int64_t outMessageNelems_{};
|
||||
at::ScalarType dType_ = at::kByte;
|
||||
std::vector<int64_t> inputSplitSizes_;
|
||||
std::vector<int64_t> outputSplitSizes_;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user