[PG Wrapper][BE] Make some methods private (#66166)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66166

These methods should be private.
ghstack-source-id: 139782587

Test Plan: CI

Reviewed By: cbalioglu

Differential Revision: D31353020

fbshipit-source-id: 583fb315cc2cacc37df3d29cd5793b42558930b3
This commit is contained in:
Rohan Varma 2021-10-08 09:07:11 -07:00 committed by Facebook GitHub Bot
parent 0cad2c0615
commit b5b1d49a66

View File

@ -46,6 +46,55 @@ struct CollectiveFingerPrint {
std::ostream& output, std::ostream& output,
const CollectiveFingerPrint& collective_fingerprint); const CollectiveFingerPrint& collective_fingerprint);
// Executes and verifies the collective fingerprint.
void verify(c10::intrusive_ptr<ProcessGroup> pg) {
at::Tensor serialized_tensor = serialize_fingerprint();
std::vector<at::Tensor> inp{serialized_tensor};
// First verify tensor shapes. This is needed because if e.g. tensor dim
// does not match across processes, directly verifying tensors will result
// in a crash during allgather, but we'd actually like to report a
// description about the inconsistency. Since the input is just a 1D tensor
// the shape will be a single int k_i and we need to make sure k_i is
// consistent across the whole world.
std::vector<at::Tensor> sp = c10d::getTensorShapes(inp);
verify_tensors(sp, pg);
// Now verify consistency for the actual tensor.
verify_tensors(inp, pg);
}
private:
void verify_tensors(
std::vector<at::Tensor>& tensors_to_verify,
c10::intrusive_ptr<ProcessGroup>& pg) {
// Create output tensor data structure to pass into allgather.
std::vector<std::vector<at::Tensor>> output_tensors;
output_tensors.reserve(tensors_to_verify.size());
for (auto& tensor_shape : tensors_to_verify) {
std::vector<at::Tensor> outputs;
outputs.reserve(pg->getSize());
for (int i = 0; i < pg->getSize(); ++i) {
outputs.emplace_back(at::zeros_like(tensor_shape));
}
output_tensors.emplace_back(outputs);
}
// Allgather tensor shapes.
pg->allgather(output_tensors, tensors_to_verify)->wait();
// Verify equivalence
for (const auto i : c10::irange(output_tensors.size())) {
const std::vector<at::Tensor> gathered_tensors = output_tensors[i];
const at::Tensor reference_tensor = tensors_to_verify[i];
for (const auto& rank_tensor : gathered_tensors) {
if (!rank_tensor.equal(reference_tensor)) {
std::stringstream ss;
ss << "Detected mismatch between collectives on ranks. Rank "
<< pg->getRank()
<< " is running inconsistent collective: " << *this;
TORCH_CHECK(false, ss.str());
}
}
}
}
at::Tensor serialize_fingerprint() { at::Tensor serialize_fingerprint() {
auto data = std::make_unique<std::vector<int64_t>>(); auto data = std::make_unique<std::vector<int64_t>>();
// std::vector<int64_t> data; // std::vector<int64_t> data;
@ -82,54 +131,6 @@ struct CollectiveFingerPrint {
.make_tensor(); .make_tensor();
return serialized_tensor; return serialized_tensor;
} }
void verify_tensors(
std::vector<at::Tensor>& tensors_to_verify,
c10::intrusive_ptr<ProcessGroup>& pg) {
// Create output tensor data structure to pass into allgather.
std::vector<std::vector<at::Tensor>> output_tensors;
output_tensors.reserve(tensors_to_verify.size());
for (auto& tensor_shape : tensors_to_verify) {
std::vector<at::Tensor> outputs;
outputs.reserve(pg->getSize());
for (int i = 0; i < pg->getSize(); ++i) {
outputs.emplace_back(at::zeros_like(tensor_shape));
}
output_tensors.emplace_back(outputs);
}
// Allgather tensor shapes.
pg->allgather(output_tensors, tensors_to_verify)->wait();
// Verify equivalence
for (const auto i : c10::irange(output_tensors.size())) {
const std::vector<at::Tensor> gathered_tensors = output_tensors[i];
const at::Tensor reference_tensor = tensors_to_verify[i];
for (const auto& rank_tensor : gathered_tensors) {
if (!rank_tensor.equal(reference_tensor)) {
std::stringstream ss;
ss << "Detected mismatch between collectives on ranks. Rank "
<< pg->getRank()
<< " is running inconsistent collective: " << *this;
TORCH_CHECK(false, ss.str());
}
}
}
}
// Executes and verifies the collective fingerprint.
void verify(c10::intrusive_ptr<ProcessGroup> pg) {
at::Tensor serialized_tensor = serialize_fingerprint();
std::vector<at::Tensor> inp{serialized_tensor};
// First verify tensor shapes. This is needed because if e.g. tensor dim
// does not match across processes, directly verifying tensors will result
// in a crash during allgather, but we'd actually like to report a
// description about the inconsistency. Since the input is just a 1D tensor
// the shape will be a single int k_i and we need to make sure k_i is
// consistent across the whole world.
std::vector<at::Tensor> sp = c10d::getTensorShapes(inp);
verify_tensors(sp, pg);
// Now verify consistency for the actual tensor.
verify_tensors(inp, pg);
}
}; };
std::ostream& operator<<( std::ostream& operator<<(