mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[PG Wrapper][BE] Make some methods private (#66166)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66166
These methods should be private.
ghstack-source-id: 139782587
Test Plan: CI
Reviewed By: cbalioglu
Differential Revision: D31353020
fbshipit-source-id: 583fb315cc2cacc37df3d29cd5793b42558930b3
This commit is contained in:
parent
0cad2c0615
commit
b5b1d49a66
|
|
@ -46,6 +46,55 @@ struct CollectiveFingerPrint {
|
|||
std::ostream& output,
|
||||
const CollectiveFingerPrint& collective_fingerprint);
|
||||
|
||||
// Executes and verifies the collective fingerprint.
|
||||
void verify(c10::intrusive_ptr<ProcessGroup> pg) {
  // Serialize this fingerprint and wrap it in a one-element list, which is
  // the form both verification passes below consume.
  std::vector<at::Tensor> fingerprint{serialize_fingerprint()};

  // Pass 1: compare shapes only. If ranks disagree on e.g. tensor dim,
  // allgathering the tensors directly would crash instead of producing a
  // readable report. The serialized fingerprint is a 1D tensor, so its shape
  // is a single int k_i, and k_i has to agree across the whole world.
  std::vector<at::Tensor> shapes = c10d::getTensorShapes(fingerprint);
  verify_tensors(shapes, pg);

  // Pass 2: shapes agree, so compare the serialized fingerprint itself.
  verify_tensors(fingerprint, pg);
}
|
||||
|
||||
private:
|
||||
void verify_tensors(
|
||||
std::vector<at::Tensor>& tensors_to_verify,
|
||||
c10::intrusive_ptr<ProcessGroup>& pg) {
|
||||
// Create output tensor data structure to pass into allgather.
|
||||
std::vector<std::vector<at::Tensor>> output_tensors;
|
||||
output_tensors.reserve(tensors_to_verify.size());
|
||||
for (auto& tensor_shape : tensors_to_verify) {
|
||||
std::vector<at::Tensor> outputs;
|
||||
outputs.reserve(pg->getSize());
|
||||
for (int i = 0; i < pg->getSize(); ++i) {
|
||||
outputs.emplace_back(at::zeros_like(tensor_shape));
|
||||
}
|
||||
output_tensors.emplace_back(outputs);
|
||||
}
|
||||
// Allgather tensor shapes.
|
||||
pg->allgather(output_tensors, tensors_to_verify)->wait();
|
||||
// Verify equivalence
|
||||
for (const auto i : c10::irange(output_tensors.size())) {
|
||||
const std::vector<at::Tensor> gathered_tensors = output_tensors[i];
|
||||
const at::Tensor reference_tensor = tensors_to_verify[i];
|
||||
for (const auto& rank_tensor : gathered_tensors) {
|
||||
if (!rank_tensor.equal(reference_tensor)) {
|
||||
std::stringstream ss;
|
||||
ss << "Detected mismatch between collectives on ranks. Rank "
|
||||
<< pg->getRank()
|
||||
<< " is running inconsistent collective: " << *this;
|
||||
TORCH_CHECK(false, ss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor serialize_fingerprint() {
|
||||
auto data = std::make_unique<std::vector<int64_t>>();
|
||||
// std::vector<int64_t> data;
|
||||
|
|
@ -82,54 +131,6 @@ struct CollectiveFingerPrint {
|
|||
.make_tensor();
|
||||
return serialized_tensor;
|
||||
}
|
||||
|
||||
void verify_tensors(
|
||||
std::vector<at::Tensor>& tensors_to_verify,
|
||||
c10::intrusive_ptr<ProcessGroup>& pg) {
|
||||
// Create output tensor data structure to pass into allgather.
|
||||
std::vector<std::vector<at::Tensor>> output_tensors;
|
||||
output_tensors.reserve(tensors_to_verify.size());
|
||||
for (auto& tensor_shape : tensors_to_verify) {
|
||||
std::vector<at::Tensor> outputs;
|
||||
outputs.reserve(pg->getSize());
|
||||
for (int i = 0; i < pg->getSize(); ++i) {
|
||||
outputs.emplace_back(at::zeros_like(tensor_shape));
|
||||
}
|
||||
output_tensors.emplace_back(outputs);
|
||||
}
|
||||
// Allgather tensor shapes.
|
||||
pg->allgather(output_tensors, tensors_to_verify)->wait();
|
||||
// Verify equivalence
|
||||
for (const auto i : c10::irange(output_tensors.size())) {
|
||||
const std::vector<at::Tensor> gathered_tensors = output_tensors[i];
|
||||
const at::Tensor reference_tensor = tensors_to_verify[i];
|
||||
for (const auto& rank_tensor : gathered_tensors) {
|
||||
if (!rank_tensor.equal(reference_tensor)) {
|
||||
std::stringstream ss;
|
||||
ss << "Detected mismatch between collectives on ranks. Rank "
|
||||
<< pg->getRank()
|
||||
<< " is running inconsistent collective: " << *this;
|
||||
TORCH_CHECK(false, ss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Executes and verifies the collective fingerprint.
|
||||
void verify(c10::intrusive_ptr<ProcessGroup> pg) {
  // Serialize this fingerprint and wrap it in a one-element list, which is
  // the form both verification passes below consume.
  std::vector<at::Tensor> fingerprint{serialize_fingerprint()};

  // Pass 1: compare shapes only. If ranks disagree on e.g. tensor dim,
  // allgathering the tensors directly would crash instead of producing a
  // readable report. The serialized fingerprint is a 1D tensor, so its shape
  // is a single int k_i, and k_i has to agree across the whole world.
  std::vector<at::Tensor> shapes = c10d::getTensorShapes(fingerprint);
  verify_tensors(shapes, pg);

  // Pass 2: shapes agree, so compare the serialized fingerprint itself.
  verify_tensors(fingerprint, pg);
}
|
||||
};
|
||||
|
||||
std::ostream& operator<<(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user