mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
### Description

- This PR renames `_all_gather_base` to `all_gather_into_tensor` so that its meaning is clearer.
- The `all_gather_into_tensor` API differs from the `all_gather` API in the output it accepts: a single, large tensor instead of a list of tensors.
- This PR also adds a deprecation warning to `_all_gather_base`.

### Issue

`_all_gather_base` was implemented in https://github.com/pytorch/pytorch/pull/33924 to avoid unnecessary flattening. A previous effort (#82639) tried to merge `_all_gather_base` into the existing `all_gather` API by detecting the type of the output parameter passed in. However, two "blockers" make the merge difficult:

(i) The merge would break backward compatibility. The parameter name `tensor_list` in `all_gather` would have to change to a general name, such as `output`, that covers both a tensor and a tensor list.

(ii) The `all_gather` API recently added support for uneven tensors, using the tensor boundaries implied by the list. We are not sure whether to add such support to `_all_gather_base`, because that would require users to pass in additional tensor-boundary information.

In view of the above, we decided to productize `_all_gather_base` as a separate function with a clearer name.

### Testing

Added tests:

- `test_all_gather_into_cat_tensor_cuda`: output form as with `torch.cat`. For example:
```
>>> tensor_in
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> tensor_out
tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
```
- `test_all_gather_into_stack_tensor_cuda`: output form as with `torch.stack`. For example:
```
>>> tensor_out2
tensor([[1, 2],
        [3, 4]], device='cuda:0') # Rank 0
tensor([[1, 2],
        [3, 4]], device='cuda:1') # Rank 1
```

The output form is determined by the shape of the output tensor passed in by the user; no flag is used.
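To make the two output forms concrete without a multi-GPU setup, here is a hypothetical single-process simulation in plain Python. Nested lists stand in for tensors, and `simulate_all_gather` and `simulate_all_gather_into_tensor` are illustrative names, not `torch.distributed` APIs; no communication takes place. The sketch assumes, as the description implies, that every rank contributes an input of the same size and that the layout is inferred only from the output shape.

```python
# Hypothetical single-process simulation of the output layouts described above.
# Each "rank" contributes an equal-sized 1-D list; nothing is actually communicated.
from typing import List, Sequence, Tuple, Union

Nested = Union[List[int], List[List[int]]]


def simulate_all_gather(inputs: Sequence[List[int]]) -> List[List[int]]:
    """List form, as with `all_gather`: one output entry per rank."""
    return [list(t) for t in inputs]


def simulate_all_gather_into_tensor(
    inputs: Sequence[List[int]], output_shape: Tuple[int, ...]
) -> Nested:
    """Single-output form, as with `all_gather_into_tensor`.

    The layout is chosen purely from the caller-provided output shape:
      (world_size * n,) -> concatenated, like torch.cat
      (world_size, n)   -> stacked, like torch.stack
    """
    world_size = len(inputs)
    n = len(inputs[0])
    # This sketch only models the even-size case.
    if any(len(t) != n for t in inputs):
        raise ValueError("all ranks must contribute inputs of the same size")
    if output_shape == (world_size * n,):
        # Flatten rank by rank: rank 0's values first, then rank 1's, ...
        return [x for t in inputs for x in t]
    if output_shape == (world_size, n):
        # One row per rank.
        return [list(t) for t in inputs]
    raise ValueError(f"unsupported output shape {output_shape}")


rank_inputs = [[1, 2], [3, 4]]  # rank 0 and rank 1, as in the test example
print(simulate_all_gather(rank_inputs))                      # [[1, 2], [3, 4]]
print(simulate_all_gather_into_tensor(rank_inputs, (4,)))    # [1, 2, 3, 4]
print(simulate_all_gather_into_tensor(rank_inputs, (2, 2)))  # [[1, 2], [3, 4]]
```

Assuming a contiguous output tensor, both shapes hold the gathered values in the same memory order; the shape only decides whether the result reads like a `torch.cat` or a `torch.stack`.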
Cc @rohan-varma @mrshenli @crcrpar @ptrblck @H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85686
Approved by: https://github.com/rohan-varma, https://github.com/crcrpar
| Name | | |
|---|---|---|
| _static | ||
| _templates | ||
| community | ||
| elastic | ||
| notes | ||
| rpc | ||
| scripts | ||
| amp.rst | ||
| autograd.rst | ||
| backends.rst | ||
| benchmark_utils.rst | ||
| bottleneck.rst | ||
| checkpoint.rst | ||
| complex_numbers.rst | ||
| conf.py | ||
| config_mod.rst | ||
| cpp_extension.rst | ||
| cpp_index.rst | ||
| cuda._sanitizer.rst | ||
| cuda.rst | ||
| cudnn_persistent_rnn.rst | ||
| cudnn_rnn_determinism.rst | ||
| data.rst | ||
| ddp_comm_hooks.rst | ||
| deploy.rst | ||
| distributed.algorithms.join.rst | ||
| distributed.elastic.rst | ||
| distributed.optim.rst | ||
| distributed.rst | ||
| distributions.rst | ||
| dlpack.rst | ||
| docutils.conf | ||
| fft.rst | ||
| fsdp.rst | ||
| futures.rst | ||
| fx.rst | ||
| hub.rst | ||
| index.rst | ||
| jit_builtin_functions.rst | ||
| jit_language_reference_v2.rst | ||
| jit_language_reference.rst | ||
| jit_python_reference.rst | ||
| jit_unsupported.rst | ||
| jit_utils.rst | ||
| jit.rst | ||
| library.rst | ||
| linalg.rst | ||
| masked.rst | ||
| math-quantizer-equation.png | ||
| mobile_optimizer.rst | ||
| model_zoo.rst | ||
| monitor.rst | ||
| multiprocessing.rst | ||
| name_inference.rst | ||
| named_tensor.rst | ||
| nested.rst | ||
| nn.functional.rst | ||
| nn.init.rst | ||
| nn.rst | ||
| onnx_supported_aten_ops.rst | ||
| onnx.rst | ||
| optim.rst | ||
| package.rst | ||
| pipeline.rst | ||
| profiler.rst | ||
| quantization-accuracy-debugging.rst | ||
| quantization-backend-configuration.rst | ||
| quantization-support.rst | ||
| quantization.rst | ||
| random.rst | ||
| rpc.rst | ||
| sparse.rst | ||
| special.rst | ||
| storage.rst | ||
| tensor_attributes.rst | ||
| tensor_view.rst | ||
| tensorboard.rst | ||
| tensors.rst | ||
| testing.rst | ||
| torch.ao.ns._numeric_suite_fx.rst | ||
| torch.ao.ns._numeric_suite.rst | ||
| torch.overrides.rst | ||
| torch.rst | ||
| type_info.rst | ||