Reverts #163712 and forces allgather/scatter inputs/outputs to be contiguous (#166181)

Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166181
Approved by: https://github.com/kwen2501
This commit is contained in:
Natalia Gimelshein 2025-10-25 02:43:07 +00:00 committed by PyTorch MergeBot
parent 761f946043
commit 2efcf3ca98
4 changed files with 12 additions and 47 deletions

View File

@ -3817,27 +3817,6 @@ class NcclProcessGroupWithDispatchedCollectivesTests(
dist.all_gather_into_tensor(output_tensor, tensor)
self.assertEqual(output_tensor, tensor)
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_allgather_noncontig(self):
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
"nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
device = "cuda"
tensor = (
torch.arange(0, 16, device=torch.device(device))
.view(2, 2, 2, 2)
.to(memory_format=torch.channels_last)
)
tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]
dist.all_gather(tensor_list, tensor)
for o in tensor_list:
self.assertEqual(o, tensor)
@requires_nccl()
@skip_if_lt_x_gpu(1)
@parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])

View File

@ -1381,8 +1381,7 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
// Use single flat output tensor.
// The first dimension corresponds to the index into outputs[N],
// so copying into the actual output later is easy.
at::Tensor flatOutputTensor =
newLikeFlat(outputs[0], /*preserve_strides*/ false);
at::Tensor flatOutputTensor = newLikeFlat(outputs[0]);
GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor);
gloo::allgather(opts);
@ -1399,7 +1398,7 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
}
const std::vector<at::Tensor> getOutputTensors() override {
return {newLikeFlat(outputs[0], /*preserve_strides*/ false)};
return {newLikeFlat(outputs[0])};
}
void run() override {
@ -1695,7 +1694,7 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork {
}
const std::vector<at::Tensor> getOutputTensors() override {
return {newLikeFlat(output_lists[0], /*preserve_strides*/ false)};
return {newLikeFlat(output_lists[0])};
}
void run() override {
@ -1819,7 +1818,7 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
// This is later scattered to the separate output tensors.
at::Tensor flatOutputTensor;
if (context_->rank == root) {
flatOutputTensor = newLikeFlat(outputs[0], /*preserve_strides*/ false);
flatOutputTensor = newLikeFlat(outputs[0]);
GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor);
}
@ -1842,8 +1841,7 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
const std::vector<at::Tensor> getOutputTensors() override {
return outputs.empty() ? std::vector<at::Tensor>{}
: std::vector<at::Tensor>{newLikeFlat(
outputs[0], /*preserve_strides*/ false)};
: std::vector<at::Tensor>{newLikeFlat(outputs[0])};
}
void run() override {
@ -2059,8 +2057,7 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork {
const std::vector<at::Tensor> getInputTensors() override {
return inputs.empty() ? std::vector<at::Tensor>{}
: std::vector<at::Tensor>{newLikeFlat(
inputs[0], /*preserve_strides*/ false)};
: std::vector<at::Tensor>{newLikeFlat(inputs[0])};
}
const std::vector<at::Tensor> getOutputTensors() override {

View File

@ -4770,6 +4770,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
bool same_size = check_same_size(outputTensors_);
if (same_size) {
// Flatten a vector of tensors into a single, stacked tensor.
// we can handle only contiguous inputs, because we are
// just sending ptr and numel to nccl
inputTensor = inputTensor.contiguous();
at::Tensor outputFlattened = newLikeFlat(outputTensors_);
return collective(
@ -4917,6 +4920,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
bool same_size = check_same_size(inputTensors_);
if (same_size) {
// Flatten a vector of tensors into a single, stacked tensor.
outputTensor = outputTensor.contiguous();
at::Tensor inputFlattened = newLikeFlat(inputTensors_);
return collective(

View File

@ -444,9 +444,7 @@ inline at::Tensor newLikeFlat(
sizes, strides, t.options().memory_format(std::nullopt));
}
inline at::Tensor newLikeFlat(
std::vector<at::Tensor>& tensors,
bool preserve_strides = true) {
inline at::Tensor newLikeFlat(std::vector<at::Tensor>& tensors) {
if (tensors.empty()) {
TORCH_CHECK(false, "Received an empty list");
}
@ -454,20 +452,7 @@ inline at::Tensor newLikeFlat(
at::DeviceGuard gpuGuard(t.device());
std::vector<int64_t> sizes{static_cast<int64_t>(tensors.size())};
sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
if (t.is_contiguous() ||
!preserve_strides) { // we are checking for memory format, so tensor might
// not be contiguous
// TODO handle all non-overlapping-and-dense, although if the strides
// disagree in ranks we are opening a door for more bugs than currently
// where channels-last might disagree between ranks
// fast path, don't call empty_strided
return at::empty(sizes, t.options());
} else {
// memory-dense, but not necessarily contiguous tensor
std::vector<int64_t> strides{t.numel()};
strides.insert(strides.end(), t.strides().begin(), t.strides().end());
return at::empty_strided(sizes, strides, t.options());
}
return at::empty(sizes, t.options());
}
inline std::vector<std::vector<int64_t>> getSizes(