Skip to content

Commit

Permalink
[Distributed/Profiler] Fix input/output dimension overflow (pytorch#1…
Browse files Browse the repository at this point in the history
…34360)

Summary: When using ParamCommsDebugInfo, the input elements and output elements are stored in `int` instead of `int64_t`

Test Plan: Run HTA with new outputted values and make sure overflow does not occur

Reviewed By: fengxizhou

Differential Revision: D61728747

Pull Request resolved: pytorch#134360
Approved by: https://github.com/fengxizhou, https://github.com/jeanschmidt
  • Loading branch information
sraikund16 authored and pytorchmergebot committed Aug 25, 2024
1 parent e93ca12 commit 8160618
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions torch/csrc/distributed/c10d/ParamCommsUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ ParamCommsDebugInfo::ParamCommsDebugInfo(
std::tuple<std::string, std::string> pgName,
int rank,
std::string&& collName,
int inNelems,
int outNelems,
int64_t inNelems,
int64_t outNelems,
at::ScalarType dType,
std::vector<int64_t> inSplitSizes,
std::vector<int64_t> outSplitSizes,
Expand Down
12 changes: 6 additions & 6 deletions torch/csrc/distributed/c10d/ParamCommsUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
std::tuple<std::string, std::string> pgName,
int rank,
std::string&& collName,
int inNelems,
int outNelems,
int64_t inNelems,
int64_t outNelems,
at::ScalarType dType,
std::vector<int64_t> inSplitSizes,
std::vector<int64_t> outSplitSizes,
Expand Down Expand Up @@ -55,11 +55,11 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
return collectiveName_;
}

int getInMessageNelems() const {
int64_t getInMessageNelems() const {
return inMessageNelems_;
}

int getOutMessageNelems() const {
int64_t getOutMessageNelems() const {
return outMessageNelems_;
}

Expand All @@ -84,8 +84,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
int rank_{};
int worldSize_{};
std::string collectiveName_;
int inMessageNelems_{};
int outMessageNelems_{};
int64_t inMessageNelems_{};
int64_t outMessageNelems_{};
at::ScalarType dType_ = at::kByte;
std::vector<int64_t> inputSplitSizes_;
std::vector<int64_t> outputSplitSizes_;
Expand Down

0 comments on commit 8160618

Please sign in to comment.