Commit Graph

42 Commits

Author SHA1 Message Date
Daniel Vega-Myhre
ae29f054f5 [Async TP] More robust support for rowwise scales when fusing matmul reduce-scatter (#149247)
Part of https://github.com/pytorch/torchtitan/issues/866

## Context
- Async TP needs to support the "reshape -> scaled_mm -> reshape" pattern because scaled mm only supports 2D input tensors and 2D scales.
    - (a,b,c) => (a*b,c)
    - (a\*b,c) @ (c,d) = (a\*b,d)
    - (a\*b,d) => (a,b,d)

- Currently the implementation does not support scaled mm with rowwise scales **for all cases** of the reshape -> scaled_mm -> reshape pattern. The minimal example of this pattern is confirmed to work via this [unit test](00a2c68f67/test/distributed/tensor/parallel/test_micro_pipeline_tp.py (L406)), but more involved e2e examples in torchtitan fail silently (more context in final bullet point).
- Previously, the "A tensor" **node** referenced in the async TP graph manipulation code is the 3D+ node before the reshape, but the "A_scale" node is the 2d node from after the reshape, so they are incompatible.
- I previously implemented a simpler solution to this problem in https://github.com/pytorch/pytorch/pull/148001, with a [unit test](https://github.com/pytorch/pytorch/pull/148001/files#diff-115f1d0852382c9b58f22640d80999d879b33618e5f6c633fc9e4d0ca9781cecR406) confirming the fused node is indeed in the graph for the minimal example of the reshape->mm->reshape pattern. I also confirmed via manual e2e testing w/ torchtitan that the crash I was fixing no longer occurred. However, it turns out due to this [bug in torchtitan](https://github.com/pytorch/torchtitan/issues/866)  it was causing async TP to fail silently and fall back to vanilla TP, hiding the fact that this original solution fixed the crash but the fusion would not occur for rowwise scales. Thus, more robust solution is needed to support all cases.

## Solution TL;DR
- Use the 2D 'A' tensor and corresponding 2D scales as input to the fused_matmul_reduce_scatter implementation, instead of the 3D+ tensor/scales.
- Track the "pre mm reshape" and "post mm reshape" separately, to be referenced in the `fused_scaled_matmul_reduce_scatter` implementation, to update the scatter dim through the pre-mm reshape, and apply the post-mm reshape before applying the reduce scatter and returning the output tensor.
- Separate the `fused_matmul_reduce_scatter` and the `fused_scaled_matmul_reduce_scatter` code paths, to simplify them both.
- By fixing the bug in torchtitan (PR https://github.com/pytorch/torchtitan/pull/965) and implementing support for rowwise scales in pytorch in this PR, together these changes will solve the problem of how to support rowwise scales with all types of AC.

## Additional details for reviewers
To use the 2D A tensor while also supporting the "reshape -> mm -> reshape" pattern, the following other changes were needed:
- Track the pre-mm reshape, as it will affect the scatter dim used in the fused_matmul_reduce_scatter impementation.
- Track the post-mm reshape, as it will affect the output shape used in the fused_matmul_reduce_scatter impementation
- Based on the pre-mm reshape and the original scatter dim, calculate the new scatter dim for the 2D tensor. This is needed because during the pipelined producer mm implementation, the scatter dim is moved to dim 0 (so it can be sharded along the first dim and then get chunks to do mm ops on by indexing into the first dim), then moved back to it's original place before the reduce-scatter.
- Use the tracked post-mm reshape to reshape the stacked partial 2D outputs of the mm ops into 3D outputs needed for 1) the reduce-scatter w/ the original scatter dim, and 2) the expected output shape to prevent shape errors with subsequent ops.

## Test plan
- All existing unit tests passing.
- Expand unit tests for rowwise scales to test more scatter dims
- Added unit tests enforcing that async TP fails fast / throws an error if it fails to perform any fusions. Previously it just "failed silently" (fell back to vanilla TP without the user knowing) which has led to confusion, so this will improve the UX.
- Compared loss curves of bf16 vs float8 w/ rowwise scales to confirm integrity of numerics
- Confirmed via manual testing with torchtitan and inspecting the compile graph that the fusion is working as intended for:
    - bfloat16
    - float8 with tensorwise scales
    - float8 with rowwise scales

## Loss curves

Loss curves are virtually identical for bf16 + vanilla TP versus float8 with rowwise scales + async TP:

<img width="1017" alt="loss_async_tp" src="https://github.com/user-attachments/assets/4995db78-7012-490f-a370-f4fecc289a22" />

## Performance

#### Per op SAC
Performance benchmarks for torchtitan Llama3 8b training runs on 4 H100s with per op SAC, using FSDP degree=2, TP degree=2:
- bf16 (vanilla TP): TPS 5161.5, peak memory 50.53 GB
- bf16 (async TP): TPS  5229.5, peak memory 50.68 GB
- float8 tensorwise (vanilla TP): TPS: 5959.5, peak memory: 50.47 GB
- float8 tensorwise (async TP): TPS 5964.5, peak memory 50.47 GB
- float8 rowwise (vanilla TP): TPS: 4962.0, peak memory: 50.55 GB
- float8 rowwise (async TP): TPS 4966.5, peak memory 50.65 GB

#### Full AC
Llama3 70b training runs on 128 H100s with full AC, using FSDP=16, TP=8
- bf16 (vanilla TP): 598 TPS, peak memory 71.51 GB
- bf16 (async TP): TPS  673, peak memory 71.08 (+12.54% TPS vs vanilla TP)
- float8 tensorwise (vanilla TP): 820 TPS, peak memory  55.26 GB
- float8 tensorwise (async TP): 950 TPS, peak memory 55.91 GB (+15.85% TPS vs vanilla TP)
- float8 rowwise (vanilla TP): TPS: 540 TPS, peak memory 71.46 GB
- float8 rowwise (async TP): 560 TPS, peak memory 70.65 GB (+3.7% TPS vs vanilla TP but still unexpectedly lower than bf16)

As you can see, float8 rowwise is working but performance needs to be improved further.

## Other changes
- Added logging so the user will know why fusion failed if it does.
- Remove logic which inserted a reshape node targeting "A scale" to get it to be in 3D like the "A tensor" since it's no longer needed.

## Long term plan
- Add a `scaled_matmul` op in pytorch, which will natively support a 3D+ "A tensor" and allow us to simplify the async TP implementation by avoiding the reshape -> scaled_mm -> reshape pattern and the special handling for it.

## Visualizing fused nodes in graphs for torchtitan training runs

Below are examples of the visualized graph generated by torch compile for torchtitan llama3 8b training runs with per op SAC. These graphs provide additional evidence (beyond the new unit tests added) that the implementation is working correctly.

### bf16

<img width="900" alt="bf16-fusion" src="https://github.com/user-attachments/assets/a3bed917-28eb-4a56-8d6e-2d2bf498385c" />

### float8 with tensorwise scales

<img width="900" alt="tensorwise-node" src="https://github.com/user-attachments/assets/b212ec4a-1899-44de-a4de-18c74e1de68a" />

### float8 with rowwise scales

<img width="900" alt="rowwise" src="https://github.com/user-attachments/assets/ed3354a3-894b-4ec9-86d0-f80364bf3d83" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149247
Approved by: https://github.com/kwen2501
2025-03-27 03:15:30 +00:00
Xuehai Pan
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
Aaron Orenstein
db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00
Luca Wehrstedt
3ee655e4d4 [async-TP] Fix scheduling in matmul+reduce-scatter for 2 ranks (#145846)
There's a sleep that is issued in order to "nudge" CUDA to do the right scheduling decision, but this is issued on iteration number 2. However, when the world size is 2, we never reach that iteration, which led to a suboptimal scheduling.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145846
Approved by: https://github.com/yifuwang
2025-01-30 18:26:34 +00:00
Aaron Orenstein
00ffeca1b1 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-21 04:23:29 +00:00
PyTorch MergeBot
6374332d33 Revert "PEP585 update - torch/distributed (#145164)"
This reverts commit 6cb186e279.

Reverted https://github.com/pytorch/pytorch/pull/145164 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing an inductor test ([comment](https://github.com/pytorch/pytorch/pull/145164#issuecomment-2602875679))
2025-01-20 16:46:46 +00:00
Aaron Orenstein
6cb186e279 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-20 00:19:01 +00:00
Yu, Guangye
176cde6240 Use torch with statement in torch distributed module (#144951)
# Motivation
In https://github.com/pytorch/pytorch/pull/137678, we help use the device-agnostic APIs to generalize distributed module. As this [comment](https://github.com/pytorch/pytorch/pull/137678#discussion_r1828645683) said, we will use the with statement of `torch.Stream` once https://github.com/pytorch/pytorch/pull/140138 is landed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144951
Approved by: https://github.com/kwen2501, https://github.com/albanD
2025-01-17 01:49:28 +00:00
Aaron Orenstein
d782e46a36 [BE] typing for decorators - library (#138969)
Test Plan: unit tests

Differential Revision: D62302678

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138969
Approved by: https://github.com/zou3519
2025-01-15 17:08:55 +00:00
bobrenjc93
08be9ec312 Migrate from Tuple -> tuple in torch/distributed (#144258)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144258
Approved by: https://github.com/aorenste
2025-01-10 08:34:54 +00:00
Yifu Wang
af190479c8 [fused_all_gather_matmul] use _multimem_all_gather_matmul for small global Ms (#143160)
## Benchmark
M=2048, N=3584, K=8192

baseline (nccl + cublas): 301us
decomp-based async-tp: 354us
comm-aware async-tp: 295us
**multimem_all_gather matmul: 277us**

As M further decreases, the multimem_all_gather approach consistently outperforms the baseline and other approaches (omitted other approaches in the chart as they start to be slower than the baseline):
![image](https://github.com/user-attachments/assets/5811455a-68c9-43fe-9d82-ca488dd77bc1)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143160
Approved by: https://github.com/weifengpy
ghstack dependencies: #142283, #142810, #143159
2024-12-17 01:07:27 +00:00
Yifu Wang
286921b39e [fused_all_gather_matmul] introduce an argument to specify whether the all-gather result needs to be returned (#143159)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143159
Approved by: https://github.com/weifengpy
ghstack dependencies: #142283, #142810
2024-12-17 01:07:27 +00:00
Yifu Wang
810808d97d Enable cutlass-based all-gather matmul when TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP is set (#142283)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142283
Approved by: https://github.com/weifengpy, https://github.com/Chillee
2024-12-13 10:29:14 +00:00
Yifu Wang
716a06d22c Mark async-tp ops as needs_fixed_stride_order (#142252)
Inductor seems to not respect the input striding of these ops, which is required for fp8 async-tp and has performance implication on other cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142252
Approved by: https://github.com/weifengpy
2024-12-07 00:42:27 +00:00
Yifu Wang
5513e2ec35 [SymmetricMemory] use the python version of empty() and rendezvous() for tests and library ops (#142154)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142154
Approved by: https://github.com/weifengpy
2024-12-05 22:09:36 +00:00
Yifu Wang
5a7e147ef3 [SymmetricMemory] introduce user-facing APIs empty() and rendezvous() (#139677)
Previously `SymmetricMemory` only had private pybind APIs:
```python
from torch.distributed._symmetric_memory import _SymmetricMemory
t = _SymmetricMemory.empty_strided_p2p(
    size=(64,),
    stride=(1,),
    dtype=torch.float32,
    device=device,
)
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group_name=group.group_name)
```

This PR introduces user-facing APIs empty() and rendezvous():
```python
import torch.distributed._symmetric_memory as symm_mem
t = symm_mem.empty(64, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group_name=group.group_name)
```

Notable differences compared to the pybind APIs:
- `empty()` now resembles `torch.empty()`:
  - shape can either be an integer sequence or pack
  - no need to/can't specify stride anymore
  - device can either be `torch.device` or string
- `group_name` needs to be specified at rendezvous time as opposed to allocation time. See https://github.com/pytorch/pytorch/pull/139529 for the rationales. I feel the new semantic is superior, hence enforcing it in the public API.
  - Currently, the pybind API still support specifying `group_name` at rendezvous time.

This PR does not change the behavior of the pybind APIs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139677
Approved by: https://github.com/lw
ghstack dependencies: #139529
2024-11-17 20:51:50 +00:00
Yifu Wang
0a0915fb5e [SymmetricMemory] improve the API for stream_write_value32 (#139934)
This PR updates the binding for `stream_write_value32` to be consistent with `memset32` which IMO makes more sense for this type of utilities:
- Changed the API to take a uint32 tensor as argument, instead of a device pointer
- Changed the Python binding to be a static method of `_SymmetricMemory`, instead of a object method
- Use the dispatcher for device dispatching, as opposed to `SymmetricMemory` backends

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139934
Approved by: https://github.com/weifengpy
ghstack dependencies: #139227
2024-11-11 18:49:22 +00:00
PyTorch MergeBot
5f4a21dc58 Revert "[SymmetricMemory] improve the API for stream_write_value32 (#139934)"
This reverts commit 2f3a5a15ef.

Reverted https://github.com/pytorch/pytorch/pull/139934 on behalf of https://github.com/malfet due to Broke distributed tests, see https://github.com/pytorch/pytorch/actions/runs/11770673088/job/32784210441 ([comment](https://github.com/pytorch/pytorch/pull/139934#issuecomment-2468641512))
2024-11-11 17:02:07 +00:00
Yifu Wang
2f3a5a15ef [SymmetricMemory] improve the API for stream_write_value32 (#139934)
This PR updates the binding for `stream_write_value32` to be consistent with `memset32` which IMO makes more sense for this type of utilities:
- Changed the API to take a uint32 tensor as argument, instead of a device pointer
- Changed the Python binding to be a static method of `_SymmetricMemory`, instead of a object method
- Use the dispatcher for device dispatching, as opposed to `SymmetricMemory` backends

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139934
Approved by: https://github.com/weifengpy
ghstack dependencies: #139227
2024-11-11 01:54:35 +00:00
Yifu Wang
1659e241c8 [experimental] async-tp impl with cutlass-based, progress aware kernel (#139227)
This PR introduces the following:

### torch.ops.symm_mem._async_input_mm

`_async_input_mm(Tensor a, Tensor b, Tensor a_chunk_signals, int a_chunk_pivot) -> Tensor`

An mm impl that supports consuming asynchronous input. It guarantees the following rasterization order, and that the corresponding signal arrives before an input chunk is consumed.
```
num_chunks = a_chunks_signals.numel()
for chunk_idx in range(a_chunk_pivot, num_chunks + a_chunk_pivot):
    chunk_idx = chunk_idx % num_chunks
    wait_signal(a_chunk_signals, chunk_idx)
    # Compute output tiles that consumes the input chunk
```

### PersistentAsyncInputScheduler

This is a forked version of PersistentScheduler that supports consuming asynchronous input. This tile scheduler introduces the following arguments:

- `tiles_per_chunk_m` – Specifies the size of an M chunk. Chunks are the granularity at which the asynchronous input becomes ready. It must be an interger multiple of the size of an M tile.
- `chunk_signals` – `chunk_signals[i] == 1` indicates that chunk i is ready. Before returning a work tile, get_current_work() waits for the signal to ensure that the corresponding chunk is ready.
- `tile_idx_pivot_m` – After applying swizzling, apply `pivot(m) => (m + tile_idx_pivot_m) % tiles_m` to `m`. In a distributed setting, this allows different ranks to process different m indices at the same time, thus avoiding communication hotspots.

Note that this scheduler currently only supports the `KernelTmaWarpSpecializedCooperative` kernel schedule. This is enforced via the template argument `KernelSchedule`.

Usage:
```
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
   Shape<int, int, int, int>,
   CollectiveMainloop,
   CollectiveEpilogue,
   cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
```

### _fused_all_gather_matmul_native
An ag-mm impl that combines `torch.ops.symm_mem._async_input_mm` and progress-aware all-gather. This is not yet enabled via the async-tp passes. We will use it as a backend to optimize the current decomposition-based async-tp impl.

## Benchmarks

### 4096x3584x8192
- cublas + nccl: 539us
- decomp-based async-tp w/o cuda graph: 694us
- decomp-based async-tp w/ cuda graph: 478us
- new cutlass kernel: 408us

<img width="478" alt="image" src="https://github.com/user-attachments/assets/39f316ab-36c5-4b41-af77-07854a385dfc">

### 2048x3584x8192
- cublas + nccl: 301us
- decomp-based async-tp w/o cuda graph: 687us
- decomp-based async-tp w/ cuda graph: 356us
- new cutlass kernel: 276us

<img width="441" alt="image" src="https://github.com/user-attachments/assets/9e23ce21-863b-43dd-a562-fb05d3a5a144">

## Next Steps
- Add tuning logic
- Use `_fused_all_gather_matmul_native` as a backend for the decomp-based async-tp impl

Differential temp Revision: [D65623152](https://our.internmc.facebook.com/intern/diff/D65623152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139227
Approved by: https://github.com/weifengpy, https://github.com/Chillee
2024-11-08 23:28:25 +00:00
PyTorch MergeBot
36e0f119d0 Revert "[experimental] async-tp impl with cutlass-based, progress aware kernel (#139227)"
This reverts commit 5203138483.

Reverted https://github.com/pytorch/pytorch/pull/139227 on behalf of https://github.com/yifuwang due to Need to address internal build failure D65605027 ([comment](https://github.com/pytorch/pytorch/pull/139227#issuecomment-2463204467))
2024-11-07 21:01:36 +00:00
Yifu Wang
5203138483 [experimental] async-tp impl with cutlass-based, progress aware kernel (#139227)
This PR introduces the following:

### torch.ops.symm_mem._async_input_mm

`_async_input_mm(Tensor a, Tensor b, Tensor a_chunk_signals, int a_chunk_pivot) -> Tensor`

An mm impl that supports consuming asynchronous input. It guarantees the following rasterization order, and that the corresponding signal arrives before an input chunk is consumed.
```
num_chunks = a_chunks_signals.numel()
for chunk_idx in range(a_chunk_pivot, num_chunks + a_chunk_pivot):
    chunk_idx = chunk_idx % num_chunks
    wait_signal(a_chunk_signals, chunk_idx)
    # Compute output tiles that consumes the input chunk
```

### PersistentAsyncInputScheduler

This is a forked version of PersistentScheduler that supports consuming asynchronous input. This tile scheduler introduces the following arguments:

- `tiles_per_chunk_m` – Specifies the size of an M chunk. Chunks are the granularity at which the asynchronous input becomes ready. It must be an interger multiple of the size of an M tile.
- `chunk_signals` – `chunk_signals[i] == 1` indicates that chunk i is ready. Before returning a work tile, get_current_work() waits for the signal to ensure that the corresponding chunk is ready.
- `tile_idx_pivot_m` – After applying swizzling, apply `pivot(m) => (m + tile_idx_pivot_m) % tiles_m` to `m`. In a distributed setting, this allows different ranks to process different m indices at the same time, thus avoiding communication hotspots.

Note that this scheduler currently only supports the `KernelTmaWarpSpecializedCooperative` kernel schedule. This is enforced via the template argument `KernelSchedule`.

Usage:
```
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
   Shape<int, int, int, int>,
   CollectiveMainloop,
   CollectiveEpilogue,
   cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
```

### _fused_all_gather_matmul_native
An ag-mm impl that combines `torch.ops.symm_mem._async_input_mm` and progress-aware all-gather. This is not yet enabled via the async-tp passes. We will use it as a backend to optimize the current decomposition-based async-tp impl.

## Benchmarks

### 4096x3584x8192
- cublas + nccl: 539us
- decomp-based async-tp w/o cuda graph: 694us
- decomp-based async-tp w/ cuda graph: 478us
- new cutlass kernel: 408us

<img width="478" alt="image" src="https://github.com/user-attachments/assets/39f316ab-36c5-4b41-af77-07854a385dfc">

### 2048x3584x8192
- cublas + nccl: 301us
- decomp-based async-tp w/o cuda graph: 687us
- decomp-based async-tp w/ cuda graph: 356us
- new cutlass kernel: 276us

<img width="441" alt="image" src="https://github.com/user-attachments/assets/9e23ce21-863b-43dd-a562-fb05d3a5a144">

## Next Steps
- Add tuning logic
- Use `_fused_all_gather_matmul_native` as a backend for the decomp-based async-tp impl

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139227
Approved by: https://github.com/weifengpy, https://github.com/Chillee
2024-11-07 03:43:12 +00:00
Yifu Wang
421473c234 get_symm_mem_workspace(): print helpful error during graph capture (#138028)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138028
Approved by: https://github.com/weifengpy
2024-10-30 18:11:09 +00:00
Yifu Wang
c69f4518ec [SymmetricMemory] fix a race condition in _pipelined_produce_and_all2all that can cause correctness issues for very small chunk_producers (#138126)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138126
Approved by: https://github.com/lessw2020
2024-10-17 01:05:41 +00:00
Yifu Wang
5d5783a263 Improve the scheduling of _pipelined_multi_all_gather_and_consume (#137850)
```
Parallelization strategy: after each rank copies its shard into its local
p2p buffer, every rank issues independent p2p copy -> shard_consumer
sequences to two streams. In addition to computation/communication
overlapping, the strategy allows for computation/computation overlapping,
greatly reducing quantization inefficiency.

Notation:
- "mv" for the copy to local buffer
- "cp" for p2p copies
- "b" for barriers

Constraints:
- The GPU scheduler may or may not overlap "mv" with the first shard_consumer.
- "cp" from different streams cannot overlap.

Ideal scenario 0 - "mv" overlaps with the first shard_consumer:

stream 0: [ shard_consumer ][ cp ][ shard_consumer ]
stream 1: [ mv ][b][ cp ][ shard_consumer ]

Ideal scenario 1 - "mv" is scheduled before the first shard_consumer:

stream 0:       [ shard_consumer ][ cp ][ shard_consumer ]
stream 1: [ mv ][b][ cp ][ shard_consumer ]

Suboptimal scenario 0 - "mv" is scheduled after the first shard_consumer:

stream 0: [ shard_consumer ]               [ cp ][ shard_consumer ]
stream 1:                   [ mv ][b][ cp ][ shard_consumer ]

Suboptimal scenario 0 - "b" is scheduled after the first shard_consumer:

stream 0:       [ shard_consumer ]         [ cp ][ shard_consumer ]
stream 1: [ mv ]                  [b][ cp ][ shard_consumer ]

We haven't yet figured out a way to ensure "mv" and "b" are either
overlapped with or scheduled before the first shard_consumer. Thus, to
prevent suboptimal scenarios, we are giving up the chance to overlap "mv"
and "b" with the first shard_consumer for now.
```

This PR improves the scheduling for mm kernels with high SM utilization. The GPU scheduler tends to not overlap local DtoD copies with such kernels, which leads to suboptimal scheduling. The following is an example of pipelining PyTorch's cutlass-based, row-wise scaling fp8 kernel:

Before this PR:
<img width="298" alt="image" src="https://github.com/user-attachments/assets/81e0a7f4-18ee-47c6-b258-04fdaca7a6a2">

With this PR:
<img width="253" alt="image" src="https://github.com/user-attachments/assets/982de5a8-da1e-4a8f-b67e-c9c869b0a77f">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137850
Approved by: https://github.com/weifengpy
ghstack dependencies: #137643, #137738, #137805, #137836
2024-10-15 21:35:14 +00:00
Yifu Wang
2ae1a4caa1 Improve the scheduling of _pipelined_produce_and_all2all (#137836)
```
Parallelization strategy: every rank issues independent compute
-> barrier -> p2p copy sequences on two streams. In addition to
computation/communication overlapping, the strategy allows for
computation/computation overlapping, greatly reducing
quantization inefficiency.

Ideally, stream activities would look like this ("b" for
barriers, "cp" for p2p copies):

[rank 0]
stream 0:         [  chunk_producer  ][b][ cp ][  chunk_producer ][b][ cp ]
stream 1: [  chunk_producer  ][b][ cp ][  chunk_producer  ][b][ cp ]

[rank 1]
stream 0:         [  chunk_producer  ][b][ cp ][  chunk_producer ][b][ cp ]
stream 1: [  chunk_producer  ][b][ cp ][  chunk_producer  ][b][ cp ]

Note that the barriers synchronize streams with the same ID
across ranks. They don't synchronize streams on the same rank.

Since the work on both streams is independent, there's no
guarantee that the chunk_producer from stream 0 or stream 1 will
be scheduled first. If there is a scheduling mismatch across
ranks, the barrier forces all ranks to wait for the slowest.

When scheduling mismatches occur among ranks, the stream
activities might look like this (note that p2p copies from
different streams cannot overlap with each other):

[rank 0]
stream 0: [  chunk_producer  ][b        ][ cp ][  chunk_producer ][b       ][ cp ]
stream 1:         [  chunk_producer  ][b]      [ cp ][  chunk_producer  ][b]      [ cp ]

[rank 1]
stream 0:         [  chunk_producer  ][b]      [ cp ][  chunk_producer  ][b]      [ cp ]
stream 1: [  chunk_producer  ][b        ][ cp ][  chunk_producer  ][b      ][ cp ]

To prevent this, we need to ensure that the chunk_producer on
stream 1 gets scheduled first on every rank. Without access to
the underlying kernels, CUDA offers no API to control the
scheduling order of two independent, overlapping kernels. Our
solution is to issue a small sleep kernel in stream 0. The sleep
duration is insignificant, but having an extra task in stream 0
will almost guarantee that the chunk_producer on stream 1 gets
scheduled first. Once the first chunk_producer is scheduled in
the correct order, there's very little room for the scheduling
order of subsequent kernels to be inconsistent across ranks.
```

Currently, we perform stream synchronization to ensure scheduling order. The stream synchronization has no bearing on correctness, but prevents inconsistent scheduling orders across ranks.

Without the stream synchronization, ranks may have inconsistent scheduling order, and the barriers cause all ranks to wait for the slowest rank:
<img width="379" alt="image" src="https://github.com/user-attachments/assets/ffb97e76-7e19-4449-b121-83c32ec3e91d">

With stream synchronization, the inconsistent scheduling order issue is addressed, but we lose compute/compute overlapping (this is the state before this PR):
<img width="378" alt="image" src="https://github.com/user-attachments/assets/4cb76246-625f-4fc1-b49a-823ae46d3f23">

With this PR, we get both consistent scheduling order across ranks and compute/compute overlap:
<img width="327" alt="image" src="https://github.com/user-attachments/assets/51ab1bdc-4f60-46e0-b53c-6d208e2d4888">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137836
Approved by: https://github.com/weifengpy
ghstack dependencies: #137643, #137738, #137805
2024-10-15 21:35:14 +00:00
Yifu Wang
ef541c1a65 [fused_all_gather_scaled_matmul] support rowwise scaling (#137805)
This PR add support for `A_scale` to be row-wise scale. The op can automatically detect whether the row-wise scale is sharded or replicated. When the row-wise scale is sharded, the op would all-gather the scale in a pipelined fashion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137805
Approved by: https://github.com/weifengpy
ghstack dependencies: #137643, #137738
2024-10-15 21:35:14 +00:00
Yifu Wang
05edaeaded [fused_scaled_matmul_reduce_scatter] support rowwise scaling (#137738)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137738
Approved by: https://github.com/Chillee, https://github.com/weifengpy
ghstack dependencies: #137643
2024-10-15 21:35:14 +00:00
Yifu Wang
38114ec860 [async-tp] fix a race condition that can cause silent correctness issue (#137199)
Details described in https://github.com/pytorch/pytorch/issues/137171:

![image](https://github.com/user-attachments/assets/8247b4f1-7805-4585-9d72-05e9475f218b)

Fix: we introduce the following invariants in `_pipelined_all_gather_and_consume` and `_pipelined_produce_and_all2all`:
- Before any stream writes to/reads from p2p buffers, perform a barrier on channel 0 on the launch stream.
- After all streams completed writing to/reading from p2p buffers, perform a barrier on channel 0 on the launch stream.

NOTE: This fix only focuses on addressing the race condition. Some barriers are exposed, which can be hidden by computation, and we'll optimize them in subsequent PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137199
Approved by: https://github.com/weifengpy
2024-10-03 10:42:37 +00:00
Yifu Wang
ea42027e0e [micro_pipeline_tp] support all _scaled_mm args (#131984)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131984
Approved by: https://github.com/weifengpy
2024-08-05 21:44:37 +00:00
Xuehai Pan
b25ef91bf1 [BE][Easy][18/19] enforce style for empty lines in import segments in torch/d*/ (#129770)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129770
Approved by: https://github.com/wconstab
2024-08-01 04:22:50 +00:00
Yifu Wang
5a33657b31 [micro_pipeline_tp] implement the pass for fused_scaled_matmul_reduce_scatter (#131951)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131951
Approved by: https://github.com/weifengpy
2024-07-30 23:02:49 +00:00
PyTorch MergeBot
a3ba405871 Revert "[BE] typing for decorators - library (#131570)"
This reverts commit 5731b486c8.

Reverted https://github.com/pytorch/pytorch/pull/131570 on behalf of https://github.com/clee2000 due to same as https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359 but I clicked the wrong link by accident.  This is where it actually starts ([comment](https://github.com/pytorch/pytorch/pull/131568#issuecomment-2254330781))
2024-07-28 03:43:39 +00:00
Yifu Wang
a8a9882899 Implement fused_scaled_matmul_reduce_scatter for async-TP (#131950)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131950
Approved by: https://github.com/weifengpy
ghstack dependencies: #131410, #131831, #131832, #131833
2024-07-28 03:39:12 +00:00
Yifu Wang
0538a69a8d [micro_pipeline_tp] support all-gather -> _scaled_mm (#131833)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131833
Approved by: https://github.com/weifengpy
ghstack dependencies: #131410, #131831, #131832
2024-07-28 03:39:11 +00:00
Yifu Wang
93a4671746 Add out_dtypes to fused_all_gather_scaled_matmul's args (#131831)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131831
Approved by: https://github.com/weifengpy
ghstack dependencies: #131410
2024-07-27 11:07:43 +00:00
Aaron Orenstein
5731b486c8 [BE] typing for decorators - library (#131570)
See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131570
Approved by: https://github.com/oulgen, https://github.com/zou3519
ghstack dependencies: #131568, #131569
2024-07-25 22:24:19 +00:00
Yifu Wang
161c18ed0b SymmetricMemory-based, low contention intra-node all-gather and reduce-scatter (#130583)
```python
# NOTE [low-contention collectives]
# When a collective is overlapped with abundant compute, it makes sense to
# prioritize reducing the contention between the collective and the overlapped
# compute, even at the cost of a slightly slower collective.
#
# Common collective implementations (e.g., NCCL without user buffer
# registration) optimize for throughput with no ambient compute. However, such
# implementations may not be optimal when they are overlapped with compute:
# - These impls typically fuse the entire collective into a single kernel and
# reserve SM resources based on the most demanding portion of the collective,
# even when a large portion of the collective does not require this much
# resource.
# - These implementations typically fuse the entire collective into a single
# kernel and reserve SM resources based on the most demanding portion of the
# collective, even when a large portion of the collective does not require this
# much resource.
# - These implementations often use SM-based P2P copy as opposed to copy
# engine-based P2P copy. Copy engine-based P2P copy may not have a significant
# advantage when there's no ambient compute. However, it may significantly
# improve overall resource utilization in the presence of ambient compute.
#
# When overlapped with intensive compute (e.g., persistent matmul kernels), the
# SM-usage of a collective can lead to inefficient overlapping.
#
# Low-contention collectives achieve their goals with the following strategies:
# - Use copy engine-based copy whenever possible.
# - Break down portions of a collective with different resource requirements
# into multiple kernels. This improves the overlapping efficiency at the cost
# of additional launching overhead.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130583
Approved by: https://github.com/weifengpy
2024-07-23 23:37:48 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function so even if the underlying function is fully typed then callers to it don't get any benefit from type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Yifu Wang
0468f2616a [SymmetricMemory] make sure different subgroups with the same name use different store prefixes (#130756)
This fixes a race condition in which different subgroups with the same name on the same host would use the same store.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130756
Approved by: https://github.com/Chillee
2024-07-16 20:21:05 +00:00
Yifu Wang
db3a641b71 Implement operator for micro-pipelined all-gather -> _scaled_mm (#129289)
This PR implements `torch.ops.symm_mem.fused_all_gather_scaled_matmul`. It's similar to `torch.ops.symm_mem.fused_all_gather_matmul`, except that it takes scales and calls ` _scaled_mm`.

[Profiling Trace vs. Baseline](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmp0gmg1f2_) (FB internal only)

Co-authored-by: Will Feng <yf225@cornell.edu>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129289
Approved by: https://github.com/Chillee, https://github.com/weifengpy, https://github.com/drisspg
2024-07-15 21:48:35 +00:00
Yifu Wang
bbd47f7b2f Remove ProcessGroupCudaP2P and change async-TP to use SymmetricMemory (#128762)
This PR removes `ProcessGroupCudaP2P` and changes async-TP to use `SymmetricMemory`. The async-TP implementation is still workspace-based, but it now doesn't require a buffer size to be specified upfront.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128762
Approved by: https://github.com/wanchaol
2024-06-25 22:32:21 +00:00