This is the implementation following the RFC (ncclCommSplit): https://github.com/pytorch/pytorch/issues/130407
Summary:
In the current PyTorch/c10d, the new_group API is used to create a new
process group from the default PG. When device_id is specified in
init_process_group and NCCL is used as the backend, the new_group call
uses ncclCommSplit to create the NCCL communicators in order to save
communicator resources. This has a few drawbacks:
Redundant calls
Suppose the default group has 256 ranks and we need 32 child PGs,
each with 8 ranks. In this case, each rank needs to call new_group and
ncclCommSplit 32 times because of how the new_group API is implemented
and the collective requirement of ncclCommSplit. For a specific global
rank, 31 of those ncclCommSplit calls are no-color splits and only 1 is
a colored split. With the proposed split_group API, only 1 call of
split_group/ncclCommSplit is needed per rank in this example.
new_group can only split from the default PG
Ideally, a new PG should be able to be split from any PG.
With the new split_group API, users can create new PGs using
ncclCommSplit with fewer calls and initialize the PGs eagerly.
This is also useful when creating many P2P communicators. A minimal sketch is shown below.
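The sketch below assumes the split_group API landed as proposed in the RFC; the keyword `split_ranks` follows the RFC text and may differ in the final API.
```
# Minimal sketch of the proposed split_group API (names follow the RFC
# and may differ in the released version).
import torch
import torch.distributed as dist

# Eager init with a bound device is required so that an NCCL communicator
# exists and can be split via ncclCommSplit.
dist.init_process_group(
    backend="nccl",
    device_id=torch.device(f"cuda:{dist.get_node_local_rank(0)}"),
)

# 256 ranks -> 32 child PGs of 8 ranks each.
# Each rank calls split_group exactly once and receives the child PG it
# belongs to, instead of issuing 31 extra no-color ncclCommSplit calls.
world_size = dist.get_world_size()
split_ranks = [list(range(i, i + 8)) for i in range(0, world_size, 8)]
my_pg = dist.split_group(split_ranks=split_ranks)

dist.destroy_process_group()
```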
Test Plan:
New UTs:
e.g., python test/distributed/test_c10d_nccl.py -k
test_comm_split_group_larger_scale
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130507
Approved by: https://github.com/wconstab
This PR removes `ProcessGroupCudaP2P` and changes async-TP to use `SymmetricMemory`. The async-TP implementation is still workspace-based, but it now doesn't require a buffer size to be specified upfront.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128762
Approved by: https://github.com/wanchaol
If users pass `device_id` to init_process_group, they enable eager init
for the default group. If they subsequently call `new_group`, the
device_id argument is not required, as it is assumed to match the
one used for init_process_group.
However, both the `init_process_group` and `new_group` APIs share a helper
function that expects a `device_id` value defaulting to None. When
it is None, eager initialization is disabled.
This PR ensures that if a device_id was passed to init_process_group,
the same device_id is automatically fed into the helper function
for any subsequent new_group calls.
**Test plan**
An existing CI test, `test_comm_split_subgroup`, failed after my change because it asserted that the backend's comm_split counter did not increment eagerly, and that behavior changed to increment eagerly. I updated the test in this PR to pass with my change.
I also tested locally via a simple program with TORCH_CPP_LOG_LEVEL=INFO and
observed eager initialization of the 'lows' and 'highs' PGs before the
'Here' print.
```
import torch
import torch.distributed as dist
dist.init_process_group(backend="nccl", device_id=torch.device(f"cuda:{torch.distributed.get_node_local_rank(0)}"))
dist.new_group([0, 1], group_desc="lows")
dist.new_group([2, 3], group_desc="highs")
print("Here")
torch.distributed.destroy_process_group()
```
Output:
https://gist.github.com/wconstab/88a5ba0b970244ca1f79133f989e0349
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129284
Approved by: https://github.com/pavanbalaji, https://github.com/fduwjj, https://github.com/d4l3k, https://github.com/nvcastet
Use `typing_extensions.deprecated` for the deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` calls where the category is missing.
Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR. An illustrative sketch follows.
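The sketch below shows the two patterns; the function names (`new_api`, `old_api`, `legacy_entry_point`) are hypothetical.
```
import warnings
from typing_extensions import deprecated

def new_api() -> None:
    pass

# Preferred: annotate with typing_extensions.deprecated, which also warns at call time.
@deprecated("`old_api` is deprecated, use `new_api` instead.", category=FutureWarning)
def old_api() -> None:
    return new_api()

# Where the decorator does not fit, pass an explicit category to warnings.warn:
def legacy_entry_point() -> None:
    warnings.warn(
        "`legacy_entry_point` is deprecated, use `new_api` instead.",
        category=FutureWarning,
        stacklevel=2,
    )
    return new_api()
```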
Resolves #126888
- #126888
This PR is split from PR #126898.
- #126898
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
Use `typing_extensions.deprecated` for the deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` calls where the category is missing.
Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves #126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
As of now, the documentation of distributed.new_group() says that it returns `None` when the current rank is not in the newly created process group. However, it actually returns `GroupMember.NON_GROUP_MEMBER`. I have checked the code and think it is more appropriate to fix the documentation, as the short sketch below illustrates.
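```
# Minimal sketch of the documented behavior: a rank that is not part of
# the new group gets GroupMember.NON_GROUP_MEMBER back, not None.
import torch.distributed as dist

dist.init_process_group(backend="nccl")
subgroup = dist.new_group(ranks=[0, 1])

if subgroup is dist.GroupMember.NON_GROUP_MEMBER:
    # This rank is not in the subgroup; skip collectives on it.
    pass

dist.destroy_process_group()
```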
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122703
Approved by: https://github.com/wconstab, https://github.com/kwen2501
Summary:
ProcessGroupNCCL sets up group_name/desc in the c10d log and in NCCL when initializing the NCCL communicator. In eager initialization mode, pg_name and pg_desc are set after communicator initialization, so the information is not available in the PyTorch log or the NCCL communicator.
This PR fixes this by setting pg_name/desc earlier.
Differential Revision: D57759816
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127053
Approved by: https://github.com/wconstab, https://github.com/kwen2501
## Context
This stack prototypes automatic micro-pipelining of `all-gather -> matmul` and `matmul -> reduce-scatter` via Inductor. The idea originates from the paper [Overlap Communication with Dependent Computation via
Decomposition in Large Deep Learning Models](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959). The implementation and some key optimizations are heavily influenced by @lw's implementation in xformers.
The stack contains several components:
- `ProcessGroupCudaP2P` - a thin wrapper around `ProcessGroupNCCL`. In addition, it maintains a P2P workspace that enables SM-free, one-sided P2P communication, which is needed for optimal micro-pipelining.
- `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops.
- Post-grad fx pass that detects `all-gather -> matmul` and `matmul -> reduce-scatter` and replaces them with the fused dispatcher ops.
To enable the prototype feature (a minimal sketch follows the note below):
- Set the distributed backend to `cuda_p2p`.
- Set `torch._inductor.config._micro_pipeline_tp` to `True`.
*NOTE: the prototype sets nothing in stone w.r.t. each component's design. The purpose is to have a performant baseline with a reasonable design on which each component can be further improved.*
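A minimal sketch of the two enabling steps above, assuming a torchrun-style launch so the rank/world-size environment variables are set.
```
import torch
import torch.distributed as dist
import torch._inductor.config as inductor_config

# 1. Route collectives through ProcessGroupCudaP2P.
dist.init_process_group(backend="cuda_p2p")

# 2. Turn on the Inductor post-grad pass that fuses
#    all-gather -> matmul and matmul -> reduce-scatter.
inductor_config._micro_pipeline_tp = True
```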
## Benchmark
Setup:
- 8 x H100 (500W) + 3rd gen NVSwitch.
- Llama3 8B training w/ torchtitan.
- 8-way TP. Reduced the number of layers from 32 to 8 for benchmarking purposes.
Trace (baseline): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpjaz8zgx0
<img width="832" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4addba77-5abc-4d2e-93ea-f68078587fe1">
Trace (w/ micro pipelining): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpn073b4wn
<img width="963" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4f44e78d-8196-43ab-a1ea-27390f07e9d2">
## This PR
`ProcessGroupCudaP2P` is a thin wrapper around `ProcessGroupNCCL`. By default, it routes all collectives to the underlying `ProcessGroupNCCL`. In addition, `ProcessGroupCudaP2P` initializes a P2P workspace that allows direct GPU memory access among the members. The workspace can be used in Python to optimize intra-node communication patterns or to create custom intra-node collectives in CUDA.
`ProcessGroupCudaP2P` aims to bridge the gap where certain important patterns can be better optimized via fine-grained P2P memory access than with collectives in the latest version of NCCL. It is meant to complement NCCL rather than replacing it.
Usage:
```
# Using ProcessGroupCudaP2P
dist.init_process_group(backend="cuda_p2p", ...)
# Using ProcessGroupCudaP2P while specifying ProcessGroupCudaP2P.Options
pg_options = ProcessGroupCudaP2P.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
# Using ProcessGroupCudaP2P while specifying ProcessGroupNCCL.Options
pg_options = ProcessGroupNCCL.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
# Using ProcessGroupCudaP2P while specifying both
# ProcessGroupCudaP2P.Options and ProcessGroupNCCL.Options
pg_options = ProcessGroupCudaP2P.Options()
pg_options.nccl_options = ProcessGroupNCCL.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
# Down-casting the backend to access p2p buffers for cuda_p2p specific
# optimizations
if is_cuda_p2p_group(group):
    backend = get_cuda_p2p_backend(group)
    if required_p2p_buffer_size > backend.get_buffer_size():
        # fallback
        ...
    p2p_buffer = backend.get_p2p_buffer(...)
else:
    # fallback
    ...
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122163
Approved by: https://github.com/wanchaol
Fixes #126379
This is the easy fix. An additional fix that I did not do is to
deregister the excepthook (or rather, restore the original one) when
calling dist.destroy_process_group. This might be a bit complicated in
practice, so landing as is for now.
Also, I couldn't figure out a clean way to test this. assertRaisesRegex
wasn't getting a string value, probably because of the stderr
redirection done via the excepthook in the first place.
Output from the original repro is cleaned up with the fix:
```
[rank0]: Traceback (most recent call last):
[rank0]: File "/data/users/whc/pytorch/except.py", line 6, in <module>
[rank0]: raise ZeroDivisionError
[rank0]: ZeroDivisionError
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126739
Approved by: https://github.com/yf225
Addresses follow-up comments on #123992 and enables the use case of
writing code that checks `get_node_local_rank(fallback_rank=0)` and
runs correctly whether launched by a launcher (e.g. torchrun)
or run locally on a single device. A short sketch follows.
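```
# Minimal sketch of the fallback_rank use case described above.
import torch
import torch.distributed as dist

# Under a launcher (e.g. torchrun) this returns LOCAL_RANK from the
# environment; when run directly with no launcher env vars, it falls
# back to 0 instead of raising.
local_rank = dist.get_node_local_rank(fallback_rank=0)
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
```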
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126737
Approved by: https://github.com/shuqiangzhang
The current call passes in `['/actual/path']` to os.walk, which is a string pointing to no existing path and thus silently results in an empty traversal.
There is an unused function just above that handles this, so I guess that is what was supposed to be called.
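A small illustration of why the failure is silent; the bad path value shown is the stringified form from the description.
```
import os

bad_path = "['/actual/path']"   # stringified list rather than a real path
# os.walk does not raise on a nonexistent path; it simply yields nothing.
print(list(os.walk(bad_path)))  # -> [] (empty traversal, no error)
```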
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126103
Approved by: https://github.com/suo
Update ruff to 0.4.1.
This version fixes a lot of false negatives/false positives, is 20-40% faster, and has various other bug fixes.
Below is a before-and-after table showing the execution time of ruff lint and ruff format in milliseconds, courtesy of https://astral.sh/blog/ruff-v0.4.0
| Repository | Linter (v0.3) | Linter (v0.4) | Formatter (v0.3) | Formatter (v0.4) |
|----------------------------------------------------|---------------|---------------|------------------|------------------|
| [pytorch/pytorch](https://github.com/pytorch/pytorch) | 328.7 | 251.8 | 351.1 | 274.9 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124549
Approved by: https://github.com/ezyang
Fixes a bug where a reference to `_ProcessGroupWrapper` is used without first checking whether gloo is available. This fails on PyTorch builds that do not include gloo because `_ProcessGroupWrapper` is only pybinded when building with `USE_GLOO=1`. Therefore, creation of a new process group fails with a `NameError` when only NCCL is available as the backend.
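A hedged sketch of the kind of guard described above (illustrative, not the exact diff); the import location of `_ProcessGroupWrapper` is an assumption.
```
import torch.distributed as dist

if dist.is_gloo_available():
    # _ProcessGroupWrapper is only pybinded when PyTorch is built with USE_GLOO=1.
    from torch._C._distributed_c10d import _ProcessGroupWrapper  # noqa: F401
    # ... wrap the new process group for debug-mode collective checks ...
else:
    # NCCL-only build: skip the wrapper instead of raising NameError.
    pass
```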
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124233
Approved by: https://github.com/rohan-varma, https://github.com/d4l3k
Summary:
This env var was introduced to safely roll out the behavior change in destroy
process group (e.g., calling ncclCommAbort). Now that this behavior change
has already been rolled out, we no longer need the env var, and we should
clean it up to keep our code cleaner.
Test Plan:
Modified/existing UTs pass
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124334
Approved by: https://github.com/wconstab
Summary:
As part of the work of unifying process group identifiers, log <group_name, group_desc> instead of the PG uid in the profiler.
- group_name remains the unique identifier, e.g. "0", "1"
- group_desc is the user-specified name, e.g. "fsdp"
Reviewed By: aaronenyeshi, kwen2501
Differential Revision: D55610682
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124035
Approved by: https://github.com/aaronenyeshi
Summary:
We need a way to allow users to set a customized description for a process group, e.g. FSDP, PP.
Here are several use cases of a user-specified group_desc:
- Logging: we can easily match a log line and understand what this collective/PG is used for.
- PyTorch traces (e.g. Kineto, Execution Trace) can benefit from the PG desc, since trace analysis and benchmarks will be able to easily differentiate PG purposes like FSDP, PP.
- Lower-layer collective (e.g. NCCL) debugging: we will be able to expose the PG desc to the NCCL communicator so that NCCL-layer operations can be easily correlated to a PG.
Solution: add a group_desc field to c10d, as sketched below.
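A minimal sketch of tagging a process group with a description via the `group_desc` argument on new_group (as also used in the eager-init example earlier in this log).
```
import torch.distributed as dist

dist.init_process_group(backend="nccl")

# The description shows up in c10d logs and traces (Kineto / Execution Trace)
# and can be propagated to the NCCL communicator, making it easy to tell
# FSDP/PP/TP groups apart.
fsdp_pg = dist.new_group(ranks=list(range(dist.get_world_size())), group_desc="fsdp")

dist.destroy_process_group()
```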
Differential Revision: D55781850
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123472
Approved by: https://github.com/kwen2501
Summary:
Pass python c10d group_name to c++ ProcessGroupNCCL so that the pg name will be consistent across different layers.
Also record pg_name in flight recorder entry.
Differential Revision: D55597200
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123117
Approved by: https://github.com/wconstab
Summary:
Minor logging cleanup in distributed library
1. Don't use "f" formatted strings - address linter issues.
2. Nits: Make use of unused `e` (error) in a few logs.
3. Change info->debug as asked in issue #113545
4. Nit: rename log -> logger in a few files for consistency
5. Fix a linter error.
Test Plan:
1. Local build passes.
2. Linter is happy.
Reviewers: wanchaol
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122921
Approved by: https://github.com/wanchaol
Summary: Process group config is essential for analyzing collective patterns. We have added this to the Execution Trace. Now expose this information in Kineto as well.
Differential Revision: D53557965
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119443
Approved by: https://github.com/kwen2501
Summary:
This PR tries to resolve issue #119215.
Basically, process group shutdown (and hence ncclCommAbort) is called in the
destroy_process_group APIs for the corresponding PGs, and in the
destructor of ProcessGroup we avoid calling abort/ncclCommAbort.
Instead, the destructor just checks whether the user has already explicitly called destroy_process_group. If
not, it logs a warning and encourages/expects users to do so
to clean up the PGs' resources. A minimal usage sketch follows.
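```
# Minimal sketch of the expected cleanup pattern described above:
# explicitly destroy process groups before the process exits so the
# ProcessGroup destructor does not have to warn about leaked resources.
import torch.distributed as dist

dist.init_process_group(backend="nccl")
try:
    # ... training / collectives ...
    pass
finally:
    # Shuts down / aborts the NCCL communicators for this PG.
    dist.destroy_process_group()
```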
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119250
Approved by: https://github.com/minsii, https://github.com/kwen2501