Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166181
Approved by: https://github.com/kwen2501

This commit is contained in:
  parent 761f946043
  commit 2efcf3ca98
@@ -3817,27 +3817,6 @@ class NcclProcessGroupWithDispatchedCollectivesTests(
         dist.all_gather_into_tensor(output_tensor, tensor)
         self.assertEqual(output_tensor, tensor)

-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    def test_allgather_noncontig(self):
-        store = dist.FileStore(self.file_name, self.world_size)
-        dist.init_process_group(
-            "nccl",
-            world_size=self.world_size,
-            rank=self.rank,
-            store=store,
-        )
-        device = "cuda"
-        tensor = (
-            torch.arange(0, 16, device=torch.device(device))
-            .view(2, 2, 2, 2)
-            .to(memory_format=torch.channels_last)
-        )
-        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]
-        dist.all_gather(tensor_list, tensor)
-        for o in tensor_list:
-            self.assertEqual(o, tensor)
-
     @requires_nccl()
     @skip_if_lt_x_gpu(1)
     @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
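For context on what the deleted test exercised, here is a minimal single-process sketch (my own, using CPU instead of CUDA and no process group) that rebuilds the same channels_last input and shows why it is not contiguous in the default memory format even though its values are unchanged:

# Single-process sketch of the deleted test's input construction.
import torch

t = torch.arange(0, 16).view(2, 2, 2, 2).to(memory_format=torch.channels_last)

print(t.is_contiguous())                                   # False (default memory format)
print(t.is_contiguous(memory_format=torch.channels_last))  # True
print(t.stride())                                          # (8, 1, 4, 2)
print(t.contiguous().stride())                             # (8, 4, 2, 1)

# .contiguous() repacks the data into row-major order without changing values,
# which is what a flat (ptr, numel) view of the buffer assumes.
assert torch.equal(t, t.contiguous())

Only the physical element order differs between the two layouts, so such a tensor can always be repacked with .contiguous() before a raw buffer is handed to the backend.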
@@ -1381,8 +1381,7 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
     // Use single flat output tensor.
     // The first dimension corresponds to the index into outputs[N],
     // so copying into the actual output later is easy.
-    at::Tensor flatOutputTensor =
-        newLikeFlat(outputs[0], /*preserve_strides*/ false);
+    at::Tensor flatOutputTensor = newLikeFlat(outputs[0]);
     GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor);
     gloo::allgather(opts);

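Aside on the "single flat output tensor" comment above: a rough Python analogue (purely illustrative, not the Gloo code path) of allocating one [world_size, *shape] buffer, letting the collective fill it, and then copying slice i into the caller's i-th output tensor:

# Illustrative analogue of the flat-output pattern; world_size, shapes and the
# stand-in for the collective are assumptions for the example.
import torch

world_size = 4
template = torch.randn(3, 5)
outputs = [torch.empty_like(template) for _ in range(world_size)]

# Analogue of newLikeFlat(outputs): one contiguous buffer with a leading
# world_size dimension.
flat = torch.empty((world_size, *template.shape), dtype=template.dtype)

# Stand-in for the collective writing every rank's contribution into `flat`.
flat.copy_(torch.stack([template + r for r in range(world_size)]))

# "Copying into the actual output later is easy": slice by the first dimension.
for i, out in enumerate(outputs):
    out.copy_(flat[i])

assert torch.equal(outputs[2], template + 2)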
@@ -1399,7 +1398,7 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
   }

   const std::vector<at::Tensor> getOutputTensors() override {
-    return {newLikeFlat(outputs[0], /*preserve_strides*/ false)};
+    return {newLikeFlat(outputs[0])};
   }

   void run() override {
@@ -1695,7 +1694,7 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork {
   }

   const std::vector<at::Tensor> getOutputTensors() override {
-    return {newLikeFlat(output_lists[0], /*preserve_strides*/ false)};
+    return {newLikeFlat(output_lists[0])};
   }

   void run() override {
@@ -1819,7 +1818,7 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
     // This is later scattered to the separate output tensors.
     at::Tensor flatOutputTensor;
     if (context_->rank == root) {
-      flatOutputTensor = newLikeFlat(outputs[0], /*preserve_strides*/ false);
+      flatOutputTensor = newLikeFlat(outputs[0]);
       GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor);
     }

@@ -1842,8 +1841,7 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {

   const std::vector<at::Tensor> getOutputTensors() override {
     return outputs.empty() ? std::vector<at::Tensor>{}
-                           : std::vector<at::Tensor>{newLikeFlat(
-                                 outputs[0], /*preserve_strides*/ false)};
+                           : std::vector<at::Tensor>{newLikeFlat(outputs[0])};
   }

   void run() override {
@@ -2059,8 +2057,7 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork {

   const std::vector<at::Tensor> getInputTensors() override {
     return inputs.empty() ? std::vector<at::Tensor>{}
-                          : std::vector<at::Tensor>{newLikeFlat(
-                                inputs[0], /*preserve_strides*/ false)};
+                          : std::vector<at::Tensor>{newLikeFlat(inputs[0])};
   }

   const std::vector<at::Tensor> getOutputTensors() override {
@@ -4770,6 +4770,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
   bool same_size = check_same_size(outputTensors_);
   if (same_size) {
     // Flatten a vector of tensors into a single, stacked tensor.
+    // we can handle only contiguous inputs, because we are
+    // just sending ptr and numel to nccl
+    inputTensor = inputTensor.contiguous();
     at::Tensor outputFlattened = newLikeFlat(outputTensors_);

     return collective(
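The new comment is the heart of the change: NCCL only ever sees a base pointer and an element count, and that pair describes a tensor faithfully only when the tensor is contiguous. A small sketch (illustrative only, no NCCL involved) of how a non-contiguous view's storage order diverges from its logical order, and how .contiguous() reconciles the two:

# Why (data_ptr, numel) is only safe for contiguous tensors.
import torch

t = torch.arange(6).view(2, 3).t()   # logical shape (3, 2), non-contiguous view

logical_order = t.reshape(-1)        # elements in logical (row-major) order
storage_order = torch.arange(6)      # the shared underlying buffer holds 0..5

print(t.is_contiguous())                                       # False
print(torch.equal(logical_order, storage_order))               # False: buffer order != logical order
print(torch.equal(t.contiguous().reshape(-1), logical_order))  # True after .contiguous()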
@@ -4917,6 +4920,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
   bool same_size = check_same_size(inputTensors_);
   if (same_size) {
     // Flatten a vector of tensors into a single, stacked tensor.
+    outputTensor = outputTensor.contiguous();
     at::Tensor inputFlattened = newLikeFlat(inputTensors_);

     return collective(
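The reduce_scatter path follows the same pattern: per-rank inputs are packed into one stacked tensor and the output is made contiguous before the collective runs. A single-process simulation (illustrative only; the world_size, shapes, and the sum reduction are assumptions for the example) of the flatten-then-reduce-scatter semantics:

# Simulated reduce_scatter: rank r ends up with the reduction of every rank's
# r-th input tensor.
import torch

world_size = 3
shape = (2, 2)

# inputs[r][i] is the tensor rank r contributes for destination rank i.
inputs = [[torch.full(shape, float(r * 10 + i)) for i in range(world_size)]
          for r in range(world_size)]

# Analogue of newLikeFlat(inputTensors_) + copy: one stacked tensor per rank.
flat_inputs = [torch.stack(rank_inputs) for rank_inputs in inputs]

# The cross-rank reduction (sum), done here in-process.
reduced = torch.stack(flat_inputs).sum(dim=0)

for r in range(world_size):
    expected = sum(inputs[src][r] for src in range(world_size))
    assert torch.equal(reduced[r], expected)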
@@ -444,9 +444,7 @@ inline at::Tensor newLikeFlat(
       sizes, strides, t.options().memory_format(std::nullopt));
 }

-inline at::Tensor newLikeFlat(
-    std::vector<at::Tensor>& tensors,
-    bool preserve_strides = true) {
+inline at::Tensor newLikeFlat(std::vector<at::Tensor>& tensors) {
   if (tensors.empty()) {
     TORCH_CHECK(false, "Received an empty list");
   }
@@ -454,20 +452,7 @@ inline at::Tensor newLikeFlat(
   at::DeviceGuard gpuGuard(t.device());
   std::vector<int64_t> sizes{static_cast<int64_t>(tensors.size())};
   sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
-  if (t.is_contiguous() ||
-      !preserve_strides) { // we are checking for memory format, so tensor might
-                           // not be contiguous
-    // TODO handle all non-overlapping-and-dense, although if the strides
-    // disagree in ranks we are opening a door for more bugs than currently
-    // where channels-last might disagree between ranks
-    // fast path, don't call empty_strided
-    return at::empty(sizes, t.options());
-  } else {
-    // memory-dense, but not necessarily contiguous tensor
-    std::vector<int64_t> strides{t.numel()};
-    strides.insert(strides.end(), t.strides().begin(), t.strides().end());
-    return at::empty_strided(sizes, strides, t.options());
-  }
+  return at::empty(sizes, t.options());
 }

 inline std::vector<std::vector<int64_t>> getSizes(
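With the preserve_strides branch gone, the helper always allocates a plain contiguous buffer shaped [tensors.size(), *tensors[0].sizes()], regardless of the template tensor's strides. A Python analogue (illustrative only; new_like_flat is a hypothetical name for this sketch) of the simplified behavior:

# Hypothetical Python analogue of the simplified newLikeFlat.
import torch

def new_like_flat(tensors):
    if len(tensors) == 0:
        raise ValueError("Received an empty list")
    t = tensors[0]
    # Contiguous [N, *shape] buffer; the template's strides are ignored.
    return torch.empty((len(tensors), *t.shape), dtype=t.dtype, device=t.device)

outputs = [torch.empty(2, 2, 2, 2).to(memory_format=torch.channels_last)
           for _ in range(4)]
flat = new_like_flat(outputs)

print(flat.shape)            # torch.Size([4, 2, 2, 2, 2])
print(flat.is_contiguous())  # True: the strided/channels-last path is gone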