Commit Graph

393 Commits

Author SHA1 Message Date
ankurneog
f497a0039c API to retrieve default distributed backend from device (#140536)
# Motivation
The distributed APIs rely on backend names to create process groups.
To abstract references to these names out of PG creation, an API is added to get the default distributed backend for a device.
The device code needs to register its device and backend via `torch.distributed.Backend.register_backend` or update the map with `torch.distributed.Backend.default_device_backend_map["device"] = "distributed_backend"` before using the API.

An example of use is added in the test file (which can also be used to check the abstracted APIs).
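
The lookup idea can be sketched in plain Python without touching `torch.distributed`. This is only an illustration: the helper names `register_device_backend` and `default_backend_for` are hypothetical, and the initial map entries are assumed for the example rather than taken from this commit.

```python
# Stand-in for the device-to-backend map named in the commit message.
# The initial entries are illustrative assumptions.
default_device_backend_map = {"cpu": "gloo", "cuda": "nccl"}


def register_device_backend(device_type, backend_name):
    """Hypothetical helper: record the default backend for a device type."""
    default_device_backend_map[device_type] = backend_name


def default_backend_for(device):
    """Hypothetical helper: accept 'cuda' or 'cuda:0' style strings."""
    device_type = device.split(":")[0]  # drop an optional device index
    if device_type not in default_device_backend_map:
        raise ValueError(f"no default backend registered for {device_type!r}")
    return default_device_backend_map[device_type]


# A new device registers itself once, then callers never name the backend.
register_device_backend("my_device", "my_backend")
```

With this in place, process-group creation code could ask for the backend by device instead of hard-coding a backend name.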

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140536
Approved by: https://github.com/kwen2501
2024-11-22 11:01:53 +00:00
Will Constable
b25c291563 [C10D] Support group ranks in P2POp and batch_isend_irecv (#141054)
Changes the semantics of `__repr__` of P2POp: s, d are now group ranks instead
of global ranks. I think this is OK since I also updated the field names
to make this obvious.

Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in #140460

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141054
Approved by: https://github.com/kwen2501
2024-11-21 14:51:56 +00:00
Junjie Wang (PyTorch)
b44ecd91ba [c10d] Switch all timer logging in c10d to wait_counter (#141154)
Summary: The original decorator-based time logger performs poorly and has limited capacity, so we replace it with PyTorch's `_WaitCounter`.

Test Plan: Tested on workload and no regression has been seen: https://fburl.com/scuba/aps_instrumentation_components/mskj73ea

Differential Revision: D66218675

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141154
Approved by: https://github.com/wz337
2024-11-21 01:10:11 +00:00
Aaron Gokaslan
12e95aa4ee [BE]: Apply PERF401 autofixes from ruff (#140980)
* Automatically applies ruff rule PERF401. Turns loops into equivalent list comprehensions, which are faster and do not leak the loop variables into the enclosing scope.
* List comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc. and are easier for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
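
As a rough illustration of the kind of rewrite PERF401 targets (the function names here are made up for the example and do not come from the PyTorch code base):

```python
def even_squares_loop(values):
    # Pattern flagged by PERF401: build a list by appending inside a for loop.
    result = []
    for v in values:
        if v % 2 == 0:
            result.append(v * v)
    return result


def even_squares_comprehension(values):
    # Equivalent list comprehension; `v` is scoped to the comprehension itself.
    return [v * v for v in values if v % 2 == 0]
```

Both functions return the same list for the same input; only the construction style differs.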

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
2024-11-20 17:52:07 +00:00
Ke Wen
70a0906f24 [c10d] Support optional backend if device_id provided (#140963)
Citing @malfet's [comment](https://github.com/pytorch/pytorch/pull/136343#pullrequestreview-2318792396) in https://github.com/pytorch/pytorch/pull/136343
> It would be great, if users do not have to modify their programs for every new backend, but rather use with torch.device('xpu'): and keep rest of the code unchanged.

This PR makes the backend specification ("nccl", "gloo") optional when the user provides a `device_id` to `init_process_group` (the acceptance of `device_id` has been previously supported for the purpose of eager init).

New user experience:
```
device = torch.device(device_type, rank % device_count)
dist.init_process_group(device_id=device)
```

The `device = torch.device(...)` line is needed anyway because the user would use it for tensor creation, etc.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140963
Approved by: https://github.com/wconstab
2024-11-19 03:17:29 +00:00
Will Constable
98e6e69b1b [C10D] Support group_dst/group_src in c10d send/recv object_list (#140847)
Also add mypy annotations

Partially addresses RFC 0042 (https://github.com/pytorch/rfcs/pull/71)
See more details/motivation in https://github.com/pytorch/pytorch/pull/140460

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140847
Approved by: https://github.com/H-Huang
ghstack dependencies: #140843
2024-11-19 01:23:08 +00:00
Will Constable
c82c46ccc7 [C10D] support group_src/dst in broadcast/reduce ops (#140843)
Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in #140460
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140843
Approved by: https://github.com/kwen2501
2024-11-19 01:23:08 +00:00
Will Constable
625c24a7f9 [C10D] Support group_dst in scatter/gather (+object) ops (#140827)
Also add missing mypy typing and a few asserts to make mypy happy

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in #140460

Note: the object collective version canonicalizes to global instead of group
rank, simply because this left more of the original code intact and
required fewer conversions overall.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140827
Approved by: https://github.com/kwen2501
2024-11-17 22:19:58 +00:00
Will Constable
f8891a764d [C10D] dedup send/recv impls (#140815)
Avoid copypaste of send/isend and recv/irecv impl.

This does change the warning issued from send to include the identifier
"isend" instead of "send", but I think that's not a big deal.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140815
Approved by: https://github.com/fegin
ghstack dependencies: #140460
2024-11-16 14:24:52 +00:00
Will Constable
3d4e68fad3 [C10D] Support group_dst/group_src in c10d send/recv (#140460)
Partly addressing RFC 0042 (https://github.com/pytorch/rfcs/pull/71)

It's annoying that 'dst' (for send) must be a global rank even when a
group is passed in. But we can't easily change 'dst' without breaking
existing cases.

Furthermore, requiring use of 'global' dst breaks the less common usage
pattern of creating a new ProcessGroup object that is not connected to
the 'default group' and thus has no logical 'global' ranks.
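
The difference between the two kinds of rank can be sketched in plain Python. This is a minimal illustration that assumes a group is described by the ordered list of global ranks it contains; the helper names are hypothetical and nothing here calls `torch.distributed`.

```python
def group_to_global(group_ranks, group_rank):
    """Hypothetical helper: position inside the group -> global rank."""
    return group_ranks[group_rank]


def global_to_group(group_ranks, global_rank):
    """Hypothetical helper: global rank -> position inside the group."""
    return group_ranks.index(global_rank)


# A subgroup made of global ranks 4..7.
subgroup = [4, 5, 6, 7]

# Sending to "member 1 of the subgroup" means global rank 5.
dst_global = group_to_global(subgroup, 1)
```

A `group_dst` style argument lets the caller pass the group-relative value directly instead of doing this conversion by hand.
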
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140460
Approved by: https://github.com/d4l3k, https://github.com/kwen2501, https://github.com/fduwjj
2024-11-16 14:24:45 +00:00
fduwjj
ba8568f7fb [c10d][logging] Add wait counter for time spent in object to tensor and tensor to object (#140414)
Originally we wanted to leverage the timer logger to measure the time spent in object to tensor and tensor to object (https://github.com/pytorch/pytorch/pull/139757), but it was reverted (internally) because of a performance regression. We now use a wait counter instead, which is more lightweight.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140414
Approved by: https://github.com/c-p-i-o, https://github.com/XilunWu, https://github.com/wz337
2024-11-13 21:10:43 +00:00
PyTorch MergeBot
e4195f8060 Revert "[logging][ez] Add timer logging for pickling and unpickle for object based collective (#139757)"
This reverts commit 41e4d88584.

Reverted https://github.com/pytorch/pytorch/pull/139757 on behalf of https://github.com/izaitsevfb due to reverted internally, see D65682470 ([comment](https://github.com/pytorch/pytorch/pull/139757#issuecomment-2471316405))
2024-11-12 18:53:37 +00:00
PyTorch MergeBot
1400fedf76 Revert "add supports_coalescing property in c10d::Backend to determine whether backend supports coalescing (#135338)"
This reverts commit e5574445b0.

Reverted https://github.com/pytorch/pytorch/pull/135338 on behalf of https://github.com/ZainRizvi due to Sorry but this is failing internally. Please see D65663382 for more details ([comment](https://github.com/pytorch/pytorch/pull/135338#issuecomment-2465911854))
2024-11-08 23:52:49 +00:00
taozhiwei
e5574445b0 add supports_coalescing property in c10d::Backend to determine whether backend supports coalescing (#135338)
1. My company is using privateuseone to connect a new hardware device and requires the use of the `batch_isend_irecv` function. However, `batch_isend_irecv` is currently only open to CUDA, so I add a `supports_coalescing` property in `c10d::Backend` to determine whether the backend supports coalescing.
2. If `pg._has_hooks` returns True, we don't need to determine if the current device is CUDA. So privateuseone can also support `pg._wait_for_pending_works`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135338
Approved by: https://github.com/kwen2501
2024-11-08 11:08:45 +00:00
Junjie Wang (PyTorch)
41e4d88584 [logging][ez] Add timer logging for pickling and unpickle for object based collective (#139757)
Summary: As discussed, we want to measure the time spent during pickling and unpickling.

Test Plan: CI

Reviewed By: wz337

Differential Revision: D65462767

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139757
Approved by: https://github.com/awgu, https://github.com/Skylion007, https://github.com/fegin, https://github.com/c-p-i-o
2024-11-05 17:40:27 +00:00
Ke Wen
f121eab018 [c10d] Remove dead Dynamo marker (#139545)
Per discussion with @anijain2305, `dynamo_unsupported_distributed_c10d_ops` is not referenced anywhere.
Removing this dead code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139545
Approved by: https://github.com/Skylion007
2024-11-03 00:40:26 +00:00
Shuqiang Zhang
4c91481656 [c10d] allow sub group to be eagerly inited even if default one is not (#138665)
Summary:
Currently, eager mode is applied either to all PGs or to none of them.
There are cases where we don't want to initialize the comms for the default
PG, but we still want to initialize the comms for a sub PG. Now, with a
device_id passed to the new group, we can achieve this.
Test Plan:
newly added UT

Resolves https://github.com/pytorch/pytorch/issues/137018

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138665
Approved by: https://github.com/kwen2501
ghstack dependencies: #138781
2024-10-24 23:51:28 +00:00
Shuqiang Zhang
fe458eef80 [c10d] fix a logic of using ncclCommSplit (#138781)
Summary:
Currently, whether split should be used depends on the size of the subgroup.
It's possible that the default PG is not eagerly initialized yet, but split is still
called.

This PR fixes this issue by removing split's dependency on subgroup size.
Test Plan:
Modified UT

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138781
Approved by: https://github.com/kwen2501
2024-10-24 16:16:35 +00:00
Tom Ritchford
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
Ke Wen
fecd370ea1 [c10d] Fix color value for comm split being negative (#137855)
Fixes https://github.com/pytorch/pytorch/issues/137856.

### Issue 1
Today under `ProcessGroupNCCL::Options`, color is declared as:
```
    int64_t split_color{0};
```
When passing this variable to `ncclCommSplit` which accepts `int`, the value may overflow and become negative, as in #137856. But NCCL API only accepts non-negative colors (or `NCCL_SPLIT_NOCOLOR`).

But that's not all.

### Issue 2
`split_color` is pybind'ed to python frontend. If we just change from `int64_t` to `int` in C++, pybind will complain:
```
[rank0]: TypeError: (): incompatible function arguments. The following argument types are supported:
[rank0]:     1. (self: torch._C._distributed_c10d.ProcessGroupNCCL.Options, arg0: int) -> None
```
This is because Python `int` represents a wider range than C++ `int`. So we cannot pass hash values -- which are potentially big ints -- from Python to C++. The PR takes the hash value modulo `c_int`'s max value.
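
Both issues can be reproduced in a few lines of standard-library Python. This is a sketch under assumptions: `color_for_ranks` is a hypothetical helper, and the choice of SHA-1 over the rank list is illustrative rather than a description of the exact hashing used in the PR.

```python
import ctypes
import hashlib

# Largest value a C `int` can hold on this platform (2**31 - 1 for a 32-bit int).
C_INT_MAX = 2 ** (ctypes.sizeof(ctypes.c_int) * 8 - 1) - 1


def color_for_ranks(ranks):
    """Hypothetical helper: derive a non-negative color that fits in a C int."""
    digest = hashlib.sha1(str(list(ranks)).encode("utf-8")).hexdigest()
    big_hash = int(digest, 16)  # 160-bit Python int, far wider than a C int
    return big_hash % C_INT_MAX  # always in [0, C_INT_MAX)


# Issue 1: truncating an out-of-range value into a C int wraps to a negative number.
wrapped = ctypes.c_int(C_INT_MAX + 1).value

# The fix: reduce the hash before it crosses the Python/C++ boundary.
color = color_for_ranks([0, 1, 2, 3])
```

Reducing the value on the Python side keeps it both non-negative and within the range that a C++ `int` parameter can accept.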

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137855
Approved by: https://github.com/wconstab
2024-10-19 03:17:19 +00:00
Ke Wen
fe148024fe [c10d][experimental] Add _abort_process_group (#132291)
Thanks @eqy for reminding me of this RFC: https://github.com/pytorch/pytorch/issues/119797

This PR is meant to:
- provide a way to abort multiple PGs without deadlocking each other.
- provide a possibility to manually handle comm errors or timeouts (and potentially recovery of such).
One can find an example from: https://github.com/NVIDIA/nccl/issues/1013

## How is it different from `destroy_process_group`?
`destroy_process_group` is meant for normal exit, while `_abort_process_group` is meant for bailout upon hangs or failures. Similar to `ncclCommDestroy` vs `ncclCommAbort`.

## What's new in `_abort_process_group`?
It added support for "group abort" semantic. The "group abort" semantic is capable of aborting multiple NCCL comms concurrently, avoiding deadlock in otherwise serialized `ncclCommAbort` executions. Details are in the [RFC](https://github.com/pytorch/pytorch/issues/119797) targeting [the hang issue in multi-comm case](https://github.com/NVIDIA/nccl/issues/1013). `Group abort` semantic is added in NCCL 2.22.

## What's next?
Ideally, the watchdog's behavior should support "group abort" too. But this is hard to implement today due to a lack of "global view" by each PG's individual watchdog. A semi-big refactor may be needed to "uplift" the watchdogs to a global level or consolidate them into one (i.e. one dog watching multiple PGs).

In any case, it may not be a bad idea to experiment with the "group abort" feature through a manual API first and then extend it to the automatic mode (watchdog).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132291
Approved by: https://github.com/eqy
2024-10-11 05:04:17 +00:00
Shuqiang Zhang
47a515d260 [c10d] simplify barrier implementation and further decouple CPU/GPU (#137516)
synchronization
Summary:
Barrier is essentially intended to block CPU thread (instead of GPU
streams). Before we used 2 stream synchronizations (1. current stream
blocked by nccl stream end event, 2. CPU thread blocked on current
stream). This is unnecessary as we already have CPU thread blocking
logic in wait(). Also, adding barrier specific code block in the general
GPU synchronize() API is intrusive and confusing.

This PR cleans this up.

Test Plan:
CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137516
Approved by: https://github.com/fduwjj, https://github.com/kwen2501
2024-10-09 23:55:28 +00:00
Ke Wen
7631a04081 [c10d] Fix the device query story of ProcessGroup (#136790)
Function `_get_pg_default_device` is being used outside of `distributed_c10d.py`.

A concern is that people may not be aware of what it actually does, due to bad naming of this function:
`Return the device to use with ``group`` for control flow usage (object collectives, barrier).`

The remediation is as follows:
- Added a deprecation warning to `_get_pg_default_device`;
- Added a private function `_get_object_coll_device` to undertake what it does;
- Added a `_device_capability` function for users who want to query the device support of a PG -- it returns a plain list, no more "default" choice.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136790
Approved by: https://github.com/H-Huang
2024-10-03 01:36:22 +00:00
Howard Huang
0ccd39a64b Fix prefix store seg fault (#136872)
fixes https://github.com/pytorch/pytorch/issues/136723

Do not allow `None` to be passed into `PrefixStore`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136872
Approved by: https://github.com/kwen2501
2024-09-30 20:43:08 +00:00
Xu Song
5997354151 Add more distributed examples (#130427)
1. Add `gather` example
2. Add device to `scatter` example
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130427
Approved by: https://github.com/kwen2501
2024-09-20 18:27:27 +00:00
fduwjj
a0c7029a75 [c10d][Reland] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931) (#135653)
We introduced the dispatchable backend for a ProcessGroup and collective in https://github.com/pytorch/pytorch/issues/86225. This PR is a follow-up cleanup to clean up the option of a ProcessGroup and ask users to either set timeout or backend later on or directly create backend after creating a PG.

Also PGNCCL is using option class from ProcessGroup but we actually should use Option from backend class. So this PR is to make the type or name to be aligned with what we are doing in cpp side. I don't change the signature for the public API, so they still use args named "pg_options"

We need to make changes to the test to make it aligned with the change.

This is an attempt to reland D62008954 with the internal errors fixed.

Differential Revision: [D62483294](https://our.internmc.facebook.com/intern/diff/D62483294/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135653
Approved by: https://github.com/wz337, https://github.com/H-Huang
2024-09-16 19:56:42 +00:00
PyTorch MergeBot
351ba3e67c Revert "[c10d] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931)"
This reverts commit 65864d0134.

Reverted https://github.com/pytorch/pytorch/pull/132931 on behalf of https://github.com/ZainRizvi due to This PR is breaking builds internally due to the removal of ProcessGroup::Options ([comment](https://github.com/pytorch/pytorch/pull/132931#issuecomment-2321862402))
2024-08-30 16:27:40 +00:00
fduwjj
65864d0134 [c10d] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931)
We introduced the dispatchable backend for a ProcessGroup and collective in https://github.com/pytorch/pytorch/issues/86225. This PR is a follow-up cleanup to clean up the option of a ProcessGroup and ask users to either set timeout or backend later on or directly create backend after creating a PG.

Also PGNCCL is using option class from ProcessGroup but we actually should use Option from backend class. So this PR is to make the type or name to be aligned with what we are doing in cpp side. I don't change the signature for the public API, so they still use args named "pg_options"

We need to make changes to the test to make it aligned with the change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132931
Approved by: https://github.com/H-Huang
2024-08-29 22:40:12 +00:00
Chien-Chin Huang
3434a54fba [CP] Rewrite ring attention backward algorithm and enablement APIs (#131351)
**What does this PR achieve**
1. This PR rewrites the ring attention backward algorithm to fuse the alltoall and overlap the gradient communication with computation.

2. Enables memory efficient attention with CP by templating the ring attention backward to verify the accuracy, as fp32 gives us higher confidence in the implementation's correctness.

3. Provides some experimental APIs to enable context parallelism.

4. Ensures CP works with torch.compiler. The combination of causal masking and torch.compiler does not work yet.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131351
Approved by: https://github.com/wanchaol
2024-08-15 16:41:51 +00:00
Xuehai Pan
758a0a88a2 [BE][Easy] enable ruff rule PIE790: unnecessary pass statement (#133200)
This PR removes unnecessary `pass` statements. This is semantically safe because the bytecode for the Python code does not change.

Note that if there is a docstring in the function, an empty function does not need a `pass` statement as a placeholder.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133200
Approved by: https://github.com/malfet, https://github.com/eqy, https://github.com/kit1980
2024-08-15 15:50:19 +00:00
Du Jiangcun
b41d62a3a2 Fix typo in docs of all_gather (#133066)
Fix a typo of docs:
```
def all_gather(tensor_list, tensor, group=None, async_op=False):
...
        [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:1')] # Rank 1
```
`cuda:0` should be `cuda:1`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133066
Approved by: https://github.com/awgu
2024-08-09 18:21:26 +00:00
Will Constable
2dbe5cb979 [C10D] Clarify warning for concurrent PG usage (#131895)
Addresses a common misconception about safety of using multiple NCCL
process groups from PyTorch.

Notably, it IS safe to use multiple process groups, so long as
communication operations from different groups are not allowed to
overlap.  (Overlap of communication operations from one group with
compute operations IS ok).

TODO: after getting feedback on the text, update other copies of the warning on other APIs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131895
Approved by: https://github.com/fduwjj
2024-08-09 17:06:46 +00:00
fduwjj
4e610924d4 [c10d] Add a new API for adding ephemeral timeout for one local rank and the timeout will reset when the first collective finishes (#130905)
We provide an API for users to add an ephemeral timeout across all PGs within one rank; the timeout resets when the first collective issued after the timeout was added finishes.

Each extension only covers collectives issued after the extension is added and before the first such collective finishes. The diagram below shows how the timeout changes:

<img width="1174" alt="image" src="https://github.com/user-attachments/assets/354923b7-581c-40de-ae0f-1cd3da273ccc">

While this feature provides flexibility in specific scenarios, it introduces statefulness to timeout setting. Therefore, it is advisable to use this API sparingly and consider alternative approaches, such as directly setting the timeout or utilizing a barrier collective (one can set any timeout to the barrier), whenever feasible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130905
Approved by: https://github.com/ezyang
2024-08-06 03:47:58 +00:00
Oguz Ulgen
72d2dba992 Add None return type to init (#132335)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
2024-08-01 15:26:45 +00:00
Ke Wen
b2118573d6 [BE] Unify PG assignments (#132230)
Python's `or` operator returns `bar` in cases of
`foo = None or bar`
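
A small sketch of the idiom, with hypothetical names (`resolve_group` and the string placeholders are not from the PyTorch source):

```python
def resolve_group(group=None, default="default_pg"):
    # `x or y` evaluates to `y` whenever `x` is falsy, which includes None.
    return group or default


picked_default = resolve_group()        # falls back to the default
picked_explicit = resolve_group("sub")  # keeps the explicit argument
```

One caveat worth keeping in mind: `or` falls back for any falsy value (such as `0` or an empty string), not only for `None`, so the idiom is appropriate only when every valid argument is truthy.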

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132230
Approved by: https://github.com/Skylion007, https://github.com/wconstab
2024-07-31 15:28:25 +00:00
Shuqiang Zhang
8158cf2f59 [c10d] Fix split_group usage when there is a single rank (#131824)
Summary:
This is a request from xlformer team to allow single rank PG/comms
Test Plan:
UT

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131824
Approved by: https://github.com/pavanbalaji, https://github.com/fduwjj
2024-07-26 18:11:17 +00:00
Shuqiang Zhang
4aef5a1134 [c10] add an option to pg_config split share (#130877)
Summary:
context is: #129865
We want to give users an option to not share comms resources so that
comm ops can overlap.
Test Plan:
Augmented UT

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130877
Approved by: https://github.com/fduwjj
2024-07-19 21:11:26 +00:00
PyTorch MergeBot
5f3d8b8788 Revert "[c10] add an option to pg_config split share (#130877)"
This reverts commit 367213a608.

Reverted https://github.com/pytorch/pytorch/pull/130877 on behalf of https://github.com/atalman due to breaks internal build ([comment](https://github.com/pytorch/pytorch/pull/130877#issuecomment-2239298810))
2024-07-19 14:24:50 +00:00
Shuqiang Zhang
367213a608 [c10] add an option to pg_config split share (#130877)
Summary:
context is: #129865
We want to give users an option to not share comms resources so that
comm ops can overlap.
Test Plan:
Augmented UT

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130877
Approved by: https://github.com/fduwjj
2024-07-18 19:03:00 +00:00
Shuqiang Zhang
77fb5b0e23 [c10d] a new Pytorch API (split_group) to create a process group (#130507)
This is the implementation following the RFC: https://github.com/pytorch/pytorch/issues/130407

ncclCommSplit
Summary:
In current PyTorch/c10d, the new_group API is used to create a new
process group from the default pg.  When device_id is specified in
init_process_group and nccl is used as the backend, the new_group call
will use ncclCommSplit to create the nccl communicators to save
communicator resources. It has a few drawbacks:

Redundant calls
Suppose the default group has 256 ranks, and we need 32 child PGs,
each with 8 ranks. In this case, each rank needs to call
new_group and ncclCommSplit 32 times because of how we implement the
new_group API and the collective requirement of ncclCommSplit. For a
specific global rank, 31 calls of ncclCommSplit would be no_color splits,
and only 1 of them is a colored split. With the proposed new split_group
API, we expect that only 1 call of split_group/ncclCommSplit is needed per
rank in the above example.
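
The arithmetic in this example can be checked with a short, pure-Python count. This is only a model of the call pattern described above (contiguous groups of 8 out of 256 ranks are assumed); it does not call NCCL or `torch.distributed`, and `new_group_split_calls` is a hypothetical name.

```python
WORLD_SIZE = 256
GROUP_SIZE = 8

# 32 child groups of 8 contiguous ranks each (layout assumed for illustration).
groups = [
    list(range(start, start + GROUP_SIZE))
    for start in range(0, WORLD_SIZE, GROUP_SIZE)
]


def new_group_split_calls(rank):
    """Count the splits one rank takes part in when every group is created collectively."""
    colored = sum(1 for members in groups if rank in members)
    no_color = len(groups) - colored
    return colored, no_color


# Every rank joins exactly one group, so the remaining calls are no_color splits.
calls_for_rank_0 = new_group_split_calls(0)
```

Under this model each rank pays for 32 collective calls to end up in a single group, which is the overhead the single split_group call is meant to remove.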

new_group can only split from default_pg
Ideally, a new pg should be able to be split from any pg

With the new split_group API, users can create new PGs using
ncclCommSplit with fewer calls and initialize the PG eagerly.
This is also useful in the cases of creating many P2P communicators.
Test Plan:
New UTs:
e.g., python test/distributed/test_c10d_nccl.py -k
test_comm_split_group_larger_scale

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130507
Approved by: https://github.com/wconstab
2024-07-15 21:26:43 +00:00
Will Constable
83a4a8b510 [C10D] clean up pointless 'or None' clause (#129522)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129522
Approved by: https://github.com/awgu
2024-06-27 22:40:11 +00:00
Yifu Wang
bbd47f7b2f Remove ProcessGroupCudaP2P and change async-TP to use SymmetricMemory (#128762)
This PR removes `ProcessGroupCudaP2P` and changes async-TP to use `SymmetricMemory`. The async-TP implementation is still workspace-based, but it now doesn't require a buffer size to be specified upfront.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128762
Approved by: https://github.com/wanchaol
2024-06-25 22:32:21 +00:00
Will Constable
e1499f6342 [C10D] Make new_group eager when used with comm_split (#129284)
If users pass `device_id` to init_process_group, they enable eager init
for the default group.  Then if they subsequently call `new_group`, the
device_id argument is not required as it should be assumed to match the
one used for init_process_group.

However, both `init_process_group` and `new_group` apis share a helper
function, which expects a `device_id` value that defaults to None.  When
it's None, eager initialization is disabled.

This PR ensures that if a device_id was passed to init_process_group,
the same device_id will automatically be fed into the helper function
for any new_group calls that follow.

**Test plan**
I found an existing test in CI  `test_comm_split_subgroup` that failed after my change, because it was asserting that backend comm_split counter did not increment eagerly, and its behavior had changed to increment eagerly.  I updated the test in the PR to pass with my change.

I also tested locally via simple program with TORCH_CPP_LOG_LEVEL=INFO and
observed eager initialization of the 'lows' and 'highs' PGs before the
'Here' print.

```
import torch
import torch.distributed as dist
dist.init_process_group(backend="nccl", device_id=torch.device(f"cuda:{torch.distributed.get_node_local_rank(0)}"))
dist.new_group([0, 1], group_desc="lows")
dist.new_group([2, 3], group_desc="highs")
print("Here")
torch.distributed.destroy_process_group()
```

Output:
https://gist.github.com/wconstab/88a5ba0b970244ca1f79133f989e0349

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129284
Approved by: https://github.com/pavanbalaji, https://github.com/fduwjj, https://github.com/d4l3k, https://github.com/nvcastet
2024-06-25 21:09:34 +00:00
Xuehai Pan
94dc3253a0 [BE][Easy] enable UFMT for torch/distributed/ (#128870)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870
Approved by: https://github.com/fegin, https://github.com/wconstab
2024-06-22 18:53:28 +00:00
PyTorch MergeBot
9c929f6ce9 Revert "[BE][Easy] enable UFMT for torch/distributed/ (#128870)"
This reverts commit a0e1e20c41.

Reverted https://github.com/pytorch/pytorch/pull/128870 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128870#issuecomment-2181780356))
2024-06-21 00:38:28 +00:00
Xuehai Pan
a0e1e20c41 [BE][Easy] enable UFMT for torch/distributed/ (#128870)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870
Approved by: https://github.com/fegin
ghstack dependencies: #128868, #128869
2024-06-18 21:49:08 +00:00
loganthomas
d77a1aaa86 DOC: add note about same sized tensors to dist.gather() (#128676)
Fixes #103305

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128676
Approved by: https://github.com/wconstab
2024-06-18 18:26:07 +00:00
Aaron Orenstein
3a0d088517 Flip default value for mypy disallow_untyped_defs [5/11] (#127842)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127842
Approved by: https://github.com/oulgen
2024-06-08 18:49:18 +00:00
Xuehai Pan
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
PyTorch MergeBot
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00