mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PG Wrapper][BE] Make some methods private (#66166)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66166 These methods should be private. ghstack-source-id: 139782587 Test Plan: CI Reviewed By: cbalioglu Differential Revision: D31353020 fbshipit-source-id: 583fb315cc2cacc37df3d29cd5793b42558930b3
This commit is contained in:
parent
0cad2c0615
commit
b5b1d49a66
|
|
@ -46,6 +46,55 @@ struct CollectiveFingerPrint {
|
||||||
std::ostream& output,
|
std::ostream& output,
|
||||||
const CollectiveFingerPrint& collective_fingerprint);
|
const CollectiveFingerPrint& collective_fingerprint);
|
||||||
|
|
||||||
|
// Executes and verifies the collective fingerprint.
|
||||||
|
void verify(c10::intrusive_ptr<ProcessGroup> pg) {
|
||||||
|
at::Tensor serialized_tensor = serialize_fingerprint();
|
||||||
|
std::vector<at::Tensor> inp{serialized_tensor};
|
||||||
|
// First verify tensor shapes. This is needed because if e.g. tensor dim
|
||||||
|
// does not match across processes, directly verifying tensors will result
|
||||||
|
// in a crash during allgather, but we'd actually like to report a
|
||||||
|
// description about the inconsistency. Since the input is just a 1D tensor
|
||||||
|
// the shape will be a single int k_i and we need to make sure k_i is
|
||||||
|
// consistent across the whole world.
|
||||||
|
std::vector<at::Tensor> sp = c10d::getTensorShapes(inp);
|
||||||
|
verify_tensors(sp, pg);
|
||||||
|
// Now verify consistency for the actual tensor.
|
||||||
|
verify_tensors(inp, pg);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void verify_tensors(
|
||||||
|
std::vector<at::Tensor>& tensors_to_verify,
|
||||||
|
c10::intrusive_ptr<ProcessGroup>& pg) {
|
||||||
|
// Create output tensor data structure to pass into allgather.
|
||||||
|
std::vector<std::vector<at::Tensor>> output_tensors;
|
||||||
|
output_tensors.reserve(tensors_to_verify.size());
|
||||||
|
for (auto& tensor_shape : tensors_to_verify) {
|
||||||
|
std::vector<at::Tensor> outputs;
|
||||||
|
outputs.reserve(pg->getSize());
|
||||||
|
for (int i = 0; i < pg->getSize(); ++i) {
|
||||||
|
outputs.emplace_back(at::zeros_like(tensor_shape));
|
||||||
|
}
|
||||||
|
output_tensors.emplace_back(outputs);
|
||||||
|
}
|
||||||
|
// Allgather tensor shapes.
|
||||||
|
pg->allgather(output_tensors, tensors_to_verify)->wait();
|
||||||
|
// Verify equivalence
|
||||||
|
for (const auto i : c10::irange(output_tensors.size())) {
|
||||||
|
const std::vector<at::Tensor> gathered_tensors = output_tensors[i];
|
||||||
|
const at::Tensor reference_tensor = tensors_to_verify[i];
|
||||||
|
for (const auto& rank_tensor : gathered_tensors) {
|
||||||
|
if (!rank_tensor.equal(reference_tensor)) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "Detected mismatch between collectives on ranks. Rank "
|
||||||
|
<< pg->getRank()
|
||||||
|
<< " is running inconsistent collective: " << *this;
|
||||||
|
TORCH_CHECK(false, ss.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
at::Tensor serialize_fingerprint() {
|
at::Tensor serialize_fingerprint() {
|
||||||
auto data = std::make_unique<std::vector<int64_t>>();
|
auto data = std::make_unique<std::vector<int64_t>>();
|
||||||
// std::vector<int64_t> data;
|
// std::vector<int64_t> data;
|
||||||
|
|
@ -82,54 +131,6 @@ struct CollectiveFingerPrint {
|
||||||
.make_tensor();
|
.make_tensor();
|
||||||
return serialized_tensor;
|
return serialized_tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
void verify_tensors(
|
|
||||||
std::vector<at::Tensor>& tensors_to_verify,
|
|
||||||
c10::intrusive_ptr<ProcessGroup>& pg) {
|
|
||||||
// Create output tensor data structure to pass into allgather.
|
|
||||||
std::vector<std::vector<at::Tensor>> output_tensors;
|
|
||||||
output_tensors.reserve(tensors_to_verify.size());
|
|
||||||
for (auto& tensor_shape : tensors_to_verify) {
|
|
||||||
std::vector<at::Tensor> outputs;
|
|
||||||
outputs.reserve(pg->getSize());
|
|
||||||
for (int i = 0; i < pg->getSize(); ++i) {
|
|
||||||
outputs.emplace_back(at::zeros_like(tensor_shape));
|
|
||||||
}
|
|
||||||
output_tensors.emplace_back(outputs);
|
|
||||||
}
|
|
||||||
// Allgather tensor shapes.
|
|
||||||
pg->allgather(output_tensors, tensors_to_verify)->wait();
|
|
||||||
// Verify equivalence
|
|
||||||
for (const auto i : c10::irange(output_tensors.size())) {
|
|
||||||
const std::vector<at::Tensor> gathered_tensors = output_tensors[i];
|
|
||||||
const at::Tensor reference_tensor = tensors_to_verify[i];
|
|
||||||
for (const auto& rank_tensor : gathered_tensors) {
|
|
||||||
if (!rank_tensor.equal(reference_tensor)) {
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "Detected mismatch between collectives on ranks. Rank "
|
|
||||||
<< pg->getRank()
|
|
||||||
<< " is running inconsistent collective: " << *this;
|
|
||||||
TORCH_CHECK(false, ss.str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Executes and verifies the collective fingerprint.
|
|
||||||
void verify(c10::intrusive_ptr<ProcessGroup> pg) {
|
|
||||||
at::Tensor serialized_tensor = serialize_fingerprint();
|
|
||||||
std::vector<at::Tensor> inp{serialized_tensor};
|
|
||||||
// First verify tensor shapes. This is needed because if e.g. tensor dim
|
|
||||||
// does not match across processes, directly verifying tensors will result
|
|
||||||
// in a crash during allgather, but we'd actually like to report a
|
|
||||||
// description about the inconsistency. Since the input is just a 1D tensor
|
|
||||||
// the shape will be a single int k_i and we need to make sure k_i is
|
|
||||||
// consistent across the whole world.
|
|
||||||
std::vector<at::Tensor> sp = c10d::getTensorShapes(inp);
|
|
||||||
verify_tensors(sp, pg);
|
|
||||||
// Now verify consistency for the actual tensor.
|
|
||||||
verify_tensors(inp, pg);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
std::ostream& operator<<(
|
std::ostream& operator<<(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user