mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Async TP] More robust support for rowwise scales when fusing matmul reduce-scatter (#149247)
Part of https://github.com/pytorch/torchtitan/issues/866
## Context
- Async TP needs to support the "reshape -> scaled_mm -> reshape" pattern because scaled mm only supports 2D input tensors and 2D scales.
- (a,b,c) => (a*b,c)
- (a\*b,c) @ (c,d) = (a\*b,d)
- (a\*b,d) => (a,b,d)
- Currently the implementation does not support scaled mm with rowwise scales **for all cases** of the reshape -> scaled_mm -> reshape pattern. The minimal example of this pattern is confirmed to work via this [unit test](00a2c68f67/test/distributed/tensor/parallel/test_micro_pipeline_tp.py (L406)), but more involved e2e examples in torchtitan fail silently (more context in final bullet point).
- Previously, the "A tensor" **node** referenced in the async TP graph manipulation code is the 3D+ node before the reshape, but the "A_scale" node is the 2d node from after the reshape, so they are incompatible.
- I previously implemented a simpler solution to this problem in https://github.com/pytorch/pytorch/pull/148001, with a [unit test](https://github.com/pytorch/pytorch/pull/148001/files#diff-115f1d0852382c9b58f22640d80999d879b33618e5f6c633fc9e4d0ca9781cecR406) confirming the fused node is indeed in the graph for the minimal example of the reshape->mm->reshape pattern. I also confirmed via manual e2e testing w/ torchtitan that the crash I was fixing no longer occurred. However, it turns out due to this [bug in torchtitan](https://github.com/pytorch/torchtitan/issues/866) it was causing async TP to fail silently and fall back to vanilla TP, hiding the fact that this original solution fixed the crash but the fusion would not occur for rowwise scales. Thus, more robust solution is needed to support all cases.
## Solution TL;DR
- Use the 2D 'A' tensor and corresponding 2D scales as input to the fused_matmul_reduce_scatter implementation, instead of the 3D+ tensor/scales.
- Track the "pre mm reshape" and "post mm reshape" separately, to be referenced in the `fused_scaled_matmul_reduce_scatter` implementation, to update the scatter dim through the pre-mm reshape, and apply the post-mm reshape before applying the reduce scatter and returning the output tensor.
- Separate the `fused_matmul_reduce_scatter` and the `fused_scaled_matmul_reduce_scatter` code paths, to simplify them both.
- By fixing the bug in torchtitan (PR https://github.com/pytorch/torchtitan/pull/965) and implementing support for rowwise scales in pytorch in this PR, together these changes will solve the problem of how to support rowwise scales with all types of AC.
## Additional details for reviewers
To use the 2D A tensor while also supporting the "reshape -> mm -> reshape" pattern, the following other changes were needed:
- Track the pre-mm reshape, as it will affect the scatter dim used in the fused_matmul_reduce_scatter impementation.
- Track the post-mm reshape, as it will affect the output shape used in the fused_matmul_reduce_scatter impementation
- Based on the pre-mm reshape and the original scatter dim, calculate the new scatter dim for the 2D tensor. This is needed because during the pipelined producer mm implementation, the scatter dim is moved to dim 0 (so it can be sharded along the first dim and then get chunks to do mm ops on by indexing into the first dim), then moved back to it's original place before the reduce-scatter.
- Use the tracked post-mm reshape to reshape the stacked partial 2D outputs of the mm ops into 3D outputs needed for 1) the reduce-scatter w/ the original scatter dim, and 2) the expected output shape to prevent shape errors with subsequent ops.
## Test plan
- All existing unit tests passing.
- Expand unit tests for rowwise scales to test more scatter dims
- Added unit tests enforcing that async TP fails fast / throws an error if it fails to perform any fusions. Previously it just "failed silently" (fell back to vanilla TP without the user knowing) which has led to confusion, so this will improve the UX.
- Compared loss curves of bf16 vs float8 w/ rowwise scales to confirm integrity of numerics
- Confirmed via manual testing with torchtitan and inspecting the compile graph that the fusion is working as intended for:
- bfloat16
- float8 with tensorwise scales
- float8 with rowwise scales
## Loss curves
Loss curves are virtually identical for bf16 + vanilla TP versus float8 with rowwise scales + async TP:
<img width="1017" alt="loss_async_tp" src="https://github.com/user-attachments/assets/4995db78-7012-490f-a370-f4fecc289a22" />
## Performance
#### Per op SAC
Performance benchmarks for torchtitan Llama3 8b training runs on 4 H100s with per op SAC, using FSDP degree=2, TP degree=2:
- bf16 (vanilla TP): TPS 5161.5, peak memory 50.53 GB
- bf16 (async TP): TPS 5229.5, peak memory 50.68 GB
- float8 tensorwise (vanilla TP): TPS: 5959.5, peak memory: 50.47 GB
- float8 tensorwise (async TP): TPS 5964.5, peak memory 50.47 GB
- float8 rowwise (vanilla TP): TPS: 4962.0, peak memory: 50.55 GB
- float8 rowwise (async TP): TPS 4966.5, peak memory 50.65 GB
#### Full AC
Llama3 70b training runs on 128 H100s with full AC, using FSDP=16, TP=8
- bf16 (vanilla TP): 598 TPS, peak memory 71.51 GB
- bf16 (async TP): TPS 673, peak memory 71.08 (+12.54% TPS vs vanilla TP)
- float8 tensorwise (vanilla TP): 820 TPS, peak memory 55.26 GB
- float8 tensorwise (async TP): 950 TPS, peak memory 55.91 GB (+15.85% TPS vs vanilla TP)
- float8 rowwise (vanilla TP): TPS: 540 TPS, peak memory 71.46 GB
- float8 rowwise (async TP): 560 TPS, peak memory 70.65 GB (+3.7% TPS vs vanilla TP but still unexpectedly lower than bf16)
As you can see, float8 rowwise is working but performance needs to be improved further.
## Other changes
- Added logging so the user will know why fusion failed if it does.
- Remove logic which inserted a reshape node targeting "A scale" to get it to be in 3D like the "A tensor" since it's no longer needed.
## Long term plan
- Add a `scaled_matmul` op in pytorch, which will natively support a 3D+ "A tensor" and allow us to simplify the async TP implementation by avoiding the reshape -> scaled_mm -> reshape pattern and the special handling for it.
## Visualizing fused nodes in graphs for torchtitan training runs
Below are examples of the visualized graph generated by torch compile for torchtitan llama3 8b training runs with per op SAC. These graphs provide additional evidence (beyond the new unit tests added) that the implementation is working correctly.
### bf16
<img width="900" alt="bf16-fusion" src="https://github.com/user-attachments/assets/a3bed917-28eb-4a56-8d6e-2d2bf498385c" />
### float8 with tensorwise scales
<img width="900" alt="tensorwise-node" src="https://github.com/user-attachments/assets/b212ec4a-1899-44de-a4de-18c74e1de68a" />
### float8 with rowwise scales
<img width="900" alt="rowwise" src="https://github.com/user-attachments/assets/ed3354a3-894b-4ec9-86d0-f80364bf3d83" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149247
Approved by: https://github.com/kwen2501
This commit is contained in:
parent
114d404b07
commit
ae29f054f5
|
|
@ -16,6 +16,7 @@ from torch._inductor.fx_passes.post_grad import remove_noop_ops, view_to_reshape
|
|||
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
|
||||
from torch.distributed._functional_collectives import (
|
||||
all_gather_tensor,
|
||||
all_reduce,
|
||||
reduce_scatter_tensor,
|
||||
)
|
||||
from torch.distributed._symmetric_memory import _test_mode
|
||||
|
|
@ -401,7 +402,7 @@ class MicroPipelineTPTest(TestCase):
|
|||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@parametrize("scatter_dim", [2])
|
||||
@parametrize("scatter_dim", [0, 1, 2])
|
||||
@fresh_inductor_cache()
|
||||
def test_fuse_scaled_matmul_reduce_scatter_rowwise_scales_reshape_mm_reshape(
|
||||
self, scatter_dim
|
||||
|
|
@ -432,11 +433,11 @@ class MicroPipelineTPTest(TestCase):
|
|||
C = C.view(*orig_shape[:-1], C.shape[-1])
|
||||
return reduce_scatter_tensor(C, "sum", scatter_dim, group)
|
||||
|
||||
A = torch.rand(1, 16, 32, device="cuda").to(torch.float8_e4m3fn)
|
||||
A = torch.rand(2, 16, 32, device="cuda").to(torch.float8_e4m3fn)
|
||||
B = torch.rand(64, 32, device="cuda").to(torch.float8_e4m3fn).T
|
||||
|
||||
# A_scale = rowwise scales
|
||||
A_scale = torch.full((1, 16, 1), 0.1, device="cuda")
|
||||
A_scale = torch.full((2, 16, 1), 0.1, device="cuda")
|
||||
|
||||
# B_scale = rowwise scales transposed for A @ B^T
|
||||
B_scale = torch.full((1, 64), 0.1, device="cuda")
|
||||
|
|
@ -462,6 +463,73 @@ class MicroPipelineTPTest(TestCase):
|
|||
self.assertIn("fused_scaled_matmul_reduce_scatter", code)
|
||||
self.assertNotIn("reduce_scatter_tensor", code)
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
def test_no_all_gathers_or_reduce_scatters(self):
|
||||
group = dist.group.WORLD
|
||||
|
||||
def no_matching_pattern(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Performs some ops which will not have any all-gather-matmul or matmul-reduce-scatter patterns.
|
||||
"""
|
||||
C = A * B
|
||||
return all_reduce(C, "sum", group)
|
||||
|
||||
A = torch.rand(2, 16, 32, device="cuda").to(torch.bfloat16)
|
||||
B = torch.rand(16, 32, device="cuda").to(torch.bfloat16)
|
||||
|
||||
gm = _make_post_grad_fx(no_matching_pattern, A, B)
|
||||
|
||||
with _test_mode():
|
||||
self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
"async TP found no matching all-gather/reduce-scatter patterns for fusion",
|
||||
micro_pipeline_tp_pass,
|
||||
gm.graph,
|
||||
)
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
def test_unsuccessful_fusion(self):
|
||||
group = dist.group.WORLD
|
||||
scatter_dim = 0
|
||||
|
||||
def no_matching_pattern(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Performs 'reshape -> reciprocal -> mm -> reshape -> reduce scatter' pattern,
|
||||
so the extra 'reciprocal' op in the middle should cause pattern matching to fail.
|
||||
"""
|
||||
out_shape = [*A.shape[:-1], B.shape[-1]]
|
||||
A = A.reshape(-1, A.shape[-1])
|
||||
|
||||
# insert extra op after reshape that will cause pattern matching to fail
|
||||
A = torch.reciprocal(A)
|
||||
|
||||
C = A @ B
|
||||
C = C.view(out_shape)
|
||||
return reduce_scatter_tensor(C, "sum", scatter_dim, group)
|
||||
|
||||
A = torch.rand(2, 16, 32, device="cuda").to(torch.bfloat16)
|
||||
B = torch.rand(16, 32, device="cuda").to(torch.bfloat16).T
|
||||
|
||||
gm = _make_post_grad_fx(no_matching_pattern, A, B)
|
||||
|
||||
with _test_mode():
|
||||
self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
"no successful fusions of matul-reduce-scatters",
|
||||
micro_pipeline_tp_pass,
|
||||
gm.graph,
|
||||
)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@parametrize("shard_dim", [0, 1])
|
||||
@fresh_inductor_cache()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
import operator
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from math import prod
|
||||
from typing import Any, cast, Optional
|
||||
|
||||
import torch
|
||||
|
|
@ -20,6 +22,7 @@ from ..pattern_matcher import (
|
|||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
aten = torch.ops.aten
|
||||
patterns = PatternMatcherPass()
|
||||
|
||||
|
|
@ -292,6 +295,8 @@ class _Matmul:
|
|||
arg_ancestor_nodes: OrderedSet[torch.fx.Node] = field(init=False)
|
||||
A_node: torch.fx.Node
|
||||
B_node: torch.fx.Node
|
||||
pre_mm_reshape: Optional[torch.fx.Node]
|
||||
post_mm_reshape: Optional[torch.fx.Node]
|
||||
|
||||
def __post_init__(self):
|
||||
assert len(self.nodes) in (1, 3)
|
||||
|
|
@ -353,8 +358,12 @@ class _Matmul:
|
|||
mm_node = match[0] if len(match) == 1 else match[1]
|
||||
return _Matmul(
|
||||
nodes=match,
|
||||
A_node=cast(torch.fx.Node, match[0].args[0]),
|
||||
B_node=cast(torch.fx.Node, mm_node.args[1]),
|
||||
A_node=cast("torch.fx.Node", match[0].args[0]),
|
||||
B_node=cast("torch.fx.Node", mm_node.args[1]),
|
||||
# _Matmul handles reshapes via custom graph manipulation logic, see `replace_with()` method.
|
||||
# TOOO: explore unifying the _Matmul and _ScaledMatmul approaches to handling reshapes.
|
||||
pre_mm_reshape=None,
|
||||
post_mm_reshape=None,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -366,6 +375,8 @@ class _ScaledMatmul(_Matmul):
|
|||
result_scale_node: Optional[torch.fx.Node]
|
||||
out_dtype: Optional[torch.dtype]
|
||||
use_fast_accum: bool
|
||||
pre_mm_reshape: Optional[torch.fx.Node]
|
||||
post_mm_reshape: Optional[torch.fx.Node]
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
|
@ -379,104 +390,23 @@ class _ScaledMatmul(_Matmul):
|
|||
aten._scaled_mm.default,
|
||||
aten.reshape.default,
|
||||
)
|
||||
mm_node = match[0] if len(match) == 1 else match[1]
|
||||
|
||||
def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any:
|
||||
if idx >= len(node.args):
|
||||
return default
|
||||
return node.args[idx]
|
||||
|
||||
def insert_reshape_op(node: torch.fx.Node):
|
||||
"""
|
||||
Given a reciprocal node with a parent reshape node,
|
||||
insert a reshape node after the reciprocal node which reshapes
|
||||
the reciprocal output back to the original shape before the first reshape.
|
||||
|
||||
Before:
|
||||
reshape (a,bc,) to (a*b,c) -> reciprocal
|
||||
|
||||
After:
|
||||
reshape (a,bc,) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c)
|
||||
|
||||
Returns the new reshape node.
|
||||
"""
|
||||
# ensure the given node matches the pattern described in the docstring
|
||||
assert node.target == aten.reciprocal.default, (
|
||||
"Node must be a aten.reciprocal.default op"
|
||||
)
|
||||
assert len(node.all_input_nodes) == 1, "Node must have exactly one parent"
|
||||
|
||||
parent_node = node.all_input_nodes[0]
|
||||
assert parent_node.target == aten.reshape.default, (
|
||||
"Parent node must be a aten.reshape.default op"
|
||||
)
|
||||
assert len(parent_node.all_input_nodes) == 1, (
|
||||
"Parent node must have exactly one input node"
|
||||
)
|
||||
|
||||
parent_input_node = parent_node.all_input_nodes[0]
|
||||
parent_input_shape = list(_get_tensor(parent_input_node).shape)
|
||||
|
||||
# insert reshape back to shape from before the parent reshape op
|
||||
graph = node.graph
|
||||
with graph.inserting_after(node):
|
||||
reshape_node = graph.call_function(
|
||||
aten.reshape.default, (node, parent_input_shape)
|
||||
)
|
||||
|
||||
# ensure all users of original node (except the reshape node) now use the reshaped node instead
|
||||
node_users = list(node.users)
|
||||
for user in node_users:
|
||||
if user != reshape_node:
|
||||
user.replace_input_with(node, reshape_node)
|
||||
|
||||
return reshape_node
|
||||
|
||||
# Use mm_node with 2D args for both A and B, even if this is a "reshape -> mm -> reshape" pattern.
|
||||
# We will store the reshapes in pre_mm_reshape and post_mm_reshape, to be referenced later to
|
||||
# produce the correct output shapes, reduce-scatter along the correct dimensions, etc.
|
||||
is_reshape_mm_reshape_pattern = match[0].target == aten.reshape.default
|
||||
mm_node = match[1] if is_reshape_mm_reshape_pattern else match[0]
|
||||
|
||||
# `A_node` is pulled directly from match rather than `mm_node` because it needs to handle
|
||||
# both of the following cases:
|
||||
#
|
||||
# Case 1: single node match (mm):
|
||||
# - match[0].args[0] will be the "A tensor" node of scaled_mm
|
||||
# - Has 2D shape
|
||||
#
|
||||
# Case 2: 3 node match (reshape -> mm -> reshape)
|
||||
# - match[0].args[0] will be the "A tensor" input to the reshape op
|
||||
# - Has 3D+ shape
|
||||
A_node = cast(torch.fx.Node, match[0].args[0])
|
||||
B_node = cast(torch.fx.Node, mm_node.args[1])
|
||||
A_scale_node = cast(torch.fx.Node, mm_node.args[2])
|
||||
B_scale_node = cast(torch.fx.Node, mm_node.args[3])
|
||||
|
||||
A_ndim = _get_tensor(A_node).ndim
|
||||
A_scale_ndim = _get_tensor(A_scale_node).ndim
|
||||
is_reciprocal_with_reshape_parent = (
|
||||
A_scale_node.target == aten.reciprocal.default
|
||||
and len(A_scale_node.all_input_nodes) == 1
|
||||
and A_scale_node.all_input_nodes[0].target == aten.reshape.default
|
||||
)
|
||||
is_tensorwise_scaling = A_scale_ndim <= 1
|
||||
|
||||
# This is a temporary workaround to handle the reshape -> scaled_mm -> reshape
|
||||
# pattern when scales are row-wise, and have been reshaped along with the target
|
||||
# tensor. See https://github.com/pytorch/pytorch/pull/148001 for details.
|
||||
#
|
||||
# If tensor dim does not match scale dim, check if the scale node follows
|
||||
# the "reshape -> reciprocal" pattern. If so, we can insert a reshape op after
|
||||
# the reciprocal, to reshape the reciprocal back to the original shape before
|
||||
# the first reshape op.
|
||||
#
|
||||
# TODO: remove this workaround once torch._scaled_matmul exists and can be used
|
||||
# to implement a more robust long-term support for 3D+ scaled matmuls.
|
||||
if (
|
||||
is_reshape_mm_reshape_pattern
|
||||
and A_ndim != A_scale_ndim
|
||||
and not is_tensorwise_scaling
|
||||
and is_reciprocal_with_reshape_parent
|
||||
):
|
||||
A_scale_node = insert_reshape_op(A_scale_node)
|
||||
pre_mm_reshape = match[0] if is_reshape_mm_reshape_pattern else None
|
||||
post_mm_reshape = match[-1] if is_reshape_mm_reshape_pattern else None
|
||||
A_node = cast("torch.fx.Node", mm_node.args[0])
|
||||
B_node = cast("torch.fx.Node", mm_node.args[1])
|
||||
A_scale_node = cast("torch.fx.Node", mm_node.args[2])
|
||||
B_scale_node = cast("torch.fx.Node", mm_node.args[3])
|
||||
|
||||
return _ScaledMatmul(
|
||||
nodes=match,
|
||||
|
|
@ -488,6 +418,8 @@ class _ScaledMatmul(_Matmul):
|
|||
result_scale_node=get_arg(mm_node, 5, None),
|
||||
out_dtype=get_arg(mm_node, 6, None),
|
||||
use_fast_accum=get_arg(mm_node, 7, False),
|
||||
pre_mm_reshape=pre_mm_reshape,
|
||||
post_mm_reshape=post_mm_reshape,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -506,8 +438,8 @@ def _find_reshape_mm_reshape(node: torch.fx.Node) -> list[_Matmul]:
|
|||
# Since the reshape -> mm -> reshape pattern would be subsumed into
|
||||
# the fused op, we only match the patterns where the shape of the
|
||||
# second reshape is matches the mm result produced by the fused op.
|
||||
matmul_input_node = cast(torch.fx.Node, node.args[0])
|
||||
B_node = cast(torch.fx.Node, mm_node.args[1])
|
||||
matmul_input_node = cast("torch.fx.Node", node.args[0])
|
||||
B_node = cast("torch.fx.Node", mm_node.args[1])
|
||||
matmul_out_shape = torch.Size(
|
||||
[
|
||||
*_get_tensor(matmul_input_node).shape[:-1],
|
||||
|
|
@ -576,7 +508,7 @@ def _insert_fused_all_gather_matmul(
|
|||
kwargs={"return_A": True},
|
||||
)
|
||||
elif mm_type == _ScaledMatmul:
|
||||
scaled_matmuls = cast(list[_ScaledMatmul], matmuls)
|
||||
scaled_matmuls = cast("list[_ScaledMatmul]", matmuls)
|
||||
return graph.call_function(
|
||||
torch.ops.symm_mem.fused_all_gather_scaled_matmul.default,
|
||||
args=(
|
||||
|
|
@ -707,7 +639,55 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
|
|||
fused_node.prepend(node)
|
||||
|
||||
|
||||
def _scatter_dim_after_reshape(
|
||||
reshape_node: torch.fx.Node, orig_scatter_dim: int
|
||||
) -> int:
|
||||
"""
|
||||
Given a reshape node and the original scatter dim for the target tensor,
|
||||
returns the new scatter dim for the reshaped tensor.
|
||||
"""
|
||||
# if there was no pre-mm reshape, scatter dim will not change.
|
||||
if not reshape_node:
|
||||
return orig_scatter_dim
|
||||
|
||||
reshape_op_output_tensor = _get_tensor(reshape_node)
|
||||
assert reshape_op_output_tensor.ndim == 2, (
|
||||
"reshape must produce 2D tensor for scaled_mm"
|
||||
)
|
||||
|
||||
assert len(reshape_node.args) >= 1, "reshape node must have at least 1 arg"
|
||||
input_tensor_node = cast(torch.fx.Node, reshape_node.args[0])
|
||||
reshape_op_input_tensor = _get_tensor(input_tensor_node)
|
||||
assert reshape_op_input_tensor.ndim > reshape_op_output_tensor.ndim, (
|
||||
"reshape must be from 3D+ to 2D"
|
||||
)
|
||||
|
||||
# Note: for a N-D tensor to be reshaped into 2D, either the leading dims or ending dims must
|
||||
# be collapsed to a single dim. First determine which of these happened.
|
||||
input_shape = reshape_op_input_tensor.shape
|
||||
output_shape = reshape_op_output_tensor.shape
|
||||
leading_dims_collapsed = output_shape[0] == prod(input_shape[:-1])
|
||||
|
||||
# Case 1: scatter dim 0 always maps to 0 after any reshape from 3D+ to 2D, regardless if
|
||||
# leading dims or ending dims were collapsed.
|
||||
if orig_scatter_dim == 0:
|
||||
return 0
|
||||
|
||||
# Case 2: scatter dim "ndim-1" always maps to 1 after any reshape from 3D+ to 2D, regardless if
|
||||
# leading dims or ending dims were collapsed.
|
||||
if orig_scatter_dim == reshape_op_input_tensor.ndim - 1:
|
||||
return 1
|
||||
|
||||
# Case 3: scatter dim was one of the middle dims (between 0 and ndim-1).
|
||||
# if the leading dims were collapsed, the new scatter dim will be 0.
|
||||
# if the ending dims were collapsed, the new scatter dim will be 1.
|
||||
return 0 if leading_dims_collapsed else 1
|
||||
|
||||
|
||||
def _find_producer_matmul(node: torch.fx.Node) -> Optional[_Matmul]:
|
||||
"""
|
||||
Returns producer matmul node if found, otherwise returns None.
|
||||
"""
|
||||
if node.target == aten.mm.default:
|
||||
return _Matmul.from_match(match=[node])
|
||||
elif node.target == aten._scaled_mm.default:
|
||||
|
|
@ -738,13 +718,21 @@ def _insert_fused_matmul_reduce_scatter(
|
|||
graph: torch.fx.Graph,
|
||||
matmul: _Matmul,
|
||||
reduce_op: str,
|
||||
scatter_dim: int,
|
||||
orig_scatter_dim: int,
|
||||
group_name: str,
|
||||
scatter_dim_after_reshape: int, # only used for reshape -> scaled_mm -> reshape pattern
|
||||
output_shape: list[int], # only used for reshape -> scaled_mm -> reshape pattern
|
||||
) -> torch.fx.Node:
|
||||
if type(matmul) == _Matmul:
|
||||
return graph.call_function(
|
||||
torch.ops.symm_mem.fused_matmul_reduce_scatter.default,
|
||||
args=(matmul.A_node, matmul.B_node, reduce_op, scatter_dim, group_name),
|
||||
args=(
|
||||
matmul.A_node,
|
||||
matmul.B_node,
|
||||
reduce_op,
|
||||
orig_scatter_dim,
|
||||
group_name,
|
||||
),
|
||||
)
|
||||
elif type(matmul) == _ScaledMatmul:
|
||||
return graph.call_function(
|
||||
|
|
@ -755,8 +743,10 @@ def _insert_fused_matmul_reduce_scatter(
|
|||
matmul.A_scale_node,
|
||||
matmul.B_scale_node,
|
||||
reduce_op,
|
||||
scatter_dim,
|
||||
orig_scatter_dim,
|
||||
scatter_dim_after_reshape,
|
||||
group_name,
|
||||
output_shape,
|
||||
matmul.bias_node,
|
||||
matmul.result_scale_node,
|
||||
matmul.out_dtype,
|
||||
|
|
@ -767,7 +757,7 @@ def _insert_fused_matmul_reduce_scatter(
|
|||
raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")
|
||||
|
||||
|
||||
def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
|
||||
def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:
|
||||
"""
|
||||
Fused the pattern
|
||||
|
||||
|
|
@ -778,19 +768,24 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
|
|||
torch.ops.symm_mem.fused_matmul_reduce_scatter(
|
||||
A, B, scatter_dim, group_name,
|
||||
)
|
||||
|
||||
Returns boolean indicating if fusion was successful or not.
|
||||
"""
|
||||
if (
|
||||
not torch.distributed.is_available()
|
||||
or not torch.distributed.is_nccl_available()
|
||||
):
|
||||
return
|
||||
log.debug(
|
||||
"torch.distributed is not available, skipping fuse_matmul_reduce_scatter fusion"
|
||||
)
|
||||
return False
|
||||
|
||||
from torch.distributed._symmetric_memory import (
|
||||
is_symm_mem_enabled_for_group,
|
||||
restride_A_for_fused_matmul_reduce_scatter,
|
||||
)
|
||||
|
||||
input_node, _rs_node, rs_res_node, reduce_op, scatter_dim, group_name = (
|
||||
input_node, _rs_node, rs_res_node, reduce_op, orig_scatter_dim, group_name = (
|
||||
reduce_scatter.input_node,
|
||||
reduce_scatter.rs_node,
|
||||
reduce_scatter.res_node,
|
||||
|
|
@ -800,40 +795,81 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
|
|||
)
|
||||
|
||||
if not is_symm_mem_enabled_for_group(group_name):
|
||||
return
|
||||
log.debug(
|
||||
"symmetric memory is not enabled for process group %s, skipping fuse_matmul_reduce_scatter fusion",
|
||||
group_name,
|
||||
)
|
||||
return False
|
||||
|
||||
# Currently fused_matmul_reduce_scatter doesn't return the matmul result,
|
||||
# so we can't apply the fusion if the matmul result is used by multiple
|
||||
# users. This is not a fundamental limitation of the fused op and can be
|
||||
# addressed if needed.
|
||||
if len(input_node.users) != 1:
|
||||
return
|
||||
log.debug(
|
||||
"matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion."
|
||||
)
|
||||
return False
|
||||
|
||||
matmul = _find_producer_matmul(input_node)
|
||||
if matmul is None:
|
||||
return
|
||||
log.debug(
|
||||
"no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion"
|
||||
)
|
||||
return False
|
||||
|
||||
if rs_res_node in matmul.arg_ancestor_nodes:
|
||||
return
|
||||
log.debug(
|
||||
"reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion"
|
||||
)
|
||||
return False
|
||||
|
||||
# We need to track 3 values for the fused scaled mm reduce scatter implementation:
|
||||
# 1. The scatter dim before the reshape, which was assigned using the original (a,b,c) @ (c,d) = (a,b,d) dims.
|
||||
# 2. The scatter dim after the reshape, to use when we are doing the 2D (a*b,c) @ (c,d) = (a,b,d) scaled mm op.
|
||||
# 3. Store expected potentially 3D+ mm output shape, so we can reshape the 2D mm output to the intended
|
||||
# 3D+ shape before applying reduce-scatter, and to prevent shape erros with subsequent ops.
|
||||
|
||||
# If 'A' was reshaped from 3D+ -> 2D for the mm, we need to determine the new scattter dim after the reshape
|
||||
# for the fused matmul reduce scatter implementation to use.
|
||||
if matmul.pre_mm_reshape:
|
||||
scatter_dim_after_maybe_reshape = _scatter_dim_after_reshape(
|
||||
matmul.pre_mm_reshape, orig_scatter_dim
|
||||
)
|
||||
else:
|
||||
scatter_dim_after_maybe_reshape = orig_scatter_dim
|
||||
|
||||
# If the 2D mm output was reshaped from 2D -> 3D+, we need to store the intended output shape for the
|
||||
# fused matmul reduce scatter implementation to use.
|
||||
if matmul.post_mm_reshape:
|
||||
output_shape = list(_get_tensor(matmul.post_mm_reshape).shape)
|
||||
else:
|
||||
A_orig_shape = list(_get_tensor(matmul.A_node).shape)
|
||||
B_shape = list(_get_tensor(matmul.B_node).shape)
|
||||
output_shape = [*A_orig_shape[:-1], B_shape[-1]]
|
||||
|
||||
graph = rs_res_node.graph
|
||||
with graph.inserting_before(rs_res_node):
|
||||
# Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
|
||||
if "val" in matmul.A_node.meta:
|
||||
restrided = restride_A_for_fused_matmul_reduce_scatter(
|
||||
_get_tensor(matmul.A_node),
|
||||
scatter_dim,
|
||||
scatter_dim_after_maybe_reshape,
|
||||
)
|
||||
matmul.A_node = graph.call_function(
|
||||
inductor_prims.force_stride_order,
|
||||
args=(matmul.A_node, restrided.stride()),
|
||||
)
|
||||
|
||||
# Replace matched subgraph with fused matmul reduce scatter node
|
||||
fused_node = _insert_fused_matmul_reduce_scatter(
|
||||
graph,
|
||||
matmul,
|
||||
reduce_op,
|
||||
scatter_dim,
|
||||
orig_scatter_dim,
|
||||
group_name,
|
||||
scatter_dim_after_maybe_reshape,
|
||||
output_shape,
|
||||
)
|
||||
reduce_scatter.replace_with(fused_node)
|
||||
reduce_scatter.erase()
|
||||
|
|
@ -848,6 +884,9 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
|
|||
if order[node] > order[fused_node]:
|
||||
fused_node.prepend(node)
|
||||
|
||||
log.debug("successfully fused matmul reduce scatter")
|
||||
return True
|
||||
|
||||
|
||||
def _get_node_to_ancestors(
|
||||
graph: torch.fx.Graph,
|
||||
|
|
@ -948,8 +987,19 @@ def micro_pipeline_tp_pass(graph: torch.fx.Graph):
|
|||
x for x in reduce_scatters if x.rs_node not in unexposed_collectives
|
||||
]
|
||||
|
||||
if not all_gathers and not reduce_scatters:
|
||||
raise AssertionError(
|
||||
"async TP found no matching all-gather/reduce-scatter patterns for fusion"
|
||||
)
|
||||
|
||||
# TODO: raise an exception if we're using async TP but failed to fuse any all-gather-matmuls
|
||||
for all_gather in all_gathers:
|
||||
fuse_all_gather_matmul(all_gather)
|
||||
|
||||
fused_reduce_scatters = False
|
||||
for reduce_scatter in reduce_scatters:
|
||||
fuse_matmul_reduce_scatter(reduce_scatter)
|
||||
if fuse_matmul_reduce_scatter(reduce_scatter):
|
||||
fused_reduce_scatters = True
|
||||
|
||||
if reduce_scatters and not fused_reduce_scatters:
|
||||
raise AssertionError("no successful fusions of matul-reduce-scatters")
|
||||
|
|
|
|||
|
|
@ -275,7 +275,7 @@ def reduce_scatter_tensor(
|
|||
group_size = c10d._get_group_size_by_name(group_name)
|
||||
|
||||
assert self.size(scatter_dim) % group_size == 0, (
|
||||
f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
|
||||
f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})"
|
||||
)
|
||||
if scatter_dim != 0:
|
||||
tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
|
||||
|
|
|
|||
|
|
@ -461,7 +461,7 @@ lib.define(
|
|||
lib.define(
|
||||
"fused_scaled_matmul_reduce_scatter("
|
||||
"Tensor A, Tensor B, Tensor A_scale, Tensor B_scale, "
|
||||
"str reduce_op, int scatter_dim, str group_name, "
|
||||
"str reduce_op, int orig_scatter_dim, int scatter_dim_after_maybe_reshape, str group_name, int[]? output_shape, "
|
||||
"Tensor? bias = None, "
|
||||
"Tensor? result_scale = None, "
|
||||
"ScalarType? out_dtype = None, "
|
||||
|
|
@ -1005,11 +1005,59 @@ def restride_A_shard_for_fused_all_gather_matmul(
|
|||
return make_contiguous_for_perm(t, perm)
|
||||
|
||||
|
||||
@torch.library.impl(lib, "fused_matmul_reduce_scatter", "CUDA")
|
||||
def _fused_matmul_reduce_scatter(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
reduce_op: str,
|
||||
scatter_dim: int,
|
||||
group_name: str,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform the following logic with micro-pipelined computation and
|
||||
communication:
|
||||
|
||||
reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name)
|
||||
|
||||
Optimal stride order for A - if A.movedim(scatter_dim, 0) is contiguous, no
|
||||
extra copy is required for input layout transformation. Otherwise A needs
|
||||
to be copied once.
|
||||
"""
|
||||
if _is_test_mode:
|
||||
return _fused_matmul_reduce_scatter_fallback(
|
||||
A, B, reduce_op, scatter_dim, group_name
|
||||
)
|
||||
|
||||
with torch.profiler.record_function("fused_matmul_reduce_scatter"):
|
||||
return _fused_matmul_reduce_scatter_impl(
|
||||
mm_out_op=torch.ops.aten.mm.out,
|
||||
A=A,
|
||||
B=B,
|
||||
kwargs={},
|
||||
out_dtype=A.dtype,
|
||||
reduce_op=reduce_op,
|
||||
scatter_dim=scatter_dim,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
|
||||
@torch.library.impl(lib, "fused_matmul_reduce_scatter", "Meta")
|
||||
def _fused_matmul_reduce_scatter_fallback(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
reduce_op: str,
|
||||
scatter_dim: int,
|
||||
group_name: str,
|
||||
) -> torch.Tensor:
|
||||
res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name)
|
||||
res = funcol.wait_tensor(res)
|
||||
return res
|
||||
|
||||
|
||||
def _fused_matmul_reduce_scatter_impl(
|
||||
mm_out_op: torch._ops.OpOverload,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
kwargs: dict[str, Any],
|
||||
out_dtype: Optional[torch.dtype],
|
||||
reduce_op: str,
|
||||
|
|
@ -1040,29 +1088,9 @@ def _fused_matmul_reduce_scatter_impl(
|
|||
x = x.flatten(0, -2)
|
||||
A_shards = x.chunk(group.size())
|
||||
|
||||
A_scale_shards = None
|
||||
if A_scale is None:
|
||||
pass
|
||||
elif A_scale.numel() == 1:
|
||||
A_scale_shards = [A_scale] * group.size()
|
||||
else:
|
||||
if A_scale.shape[:-1] != A.shape[:-1]:
|
||||
raise ValueError(
|
||||
"For row-wise scaling, the leading dims of A_scale "
|
||||
"must match the leading dims of A "
|
||||
f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
|
||||
)
|
||||
A_scale = A_scale.movedim(scatter_dim, 0).contiguous().flatten(0, -2)
|
||||
A_scale_shards = list(A_scale.chunk(group.size()))
|
||||
|
||||
# Computing block-wise matmul along the first dim of A
|
||||
def chunk_producer(rank: int, out: torch.Tensor) -> None:
|
||||
if A_scale_shards is not None:
|
||||
mm_out_op(
|
||||
A_shards[rank], B, scale_a=A_scale_shards[rank], **kwargs, out=out
|
||||
)
|
||||
else:
|
||||
mm_out_op(A_shards[rank], B, **kwargs, out=out)
|
||||
mm_out_op(A_shards[rank], B, **kwargs, out=out)
|
||||
|
||||
stacked_partials = x.new_empty(x.shape[0], B.shape[1], dtype=out_dtype or A.dtype)
|
||||
|
||||
|
|
@ -1081,53 +1109,57 @@ def _fused_matmul_reduce_scatter_impl(
|
|||
)
|
||||
|
||||
|
||||
@torch.library.impl(lib, "fused_matmul_reduce_scatter", "Meta")
|
||||
def _fused_matmul_reduce_scatter_fallback(
|
||||
@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "CUDA")
|
||||
def _fused_scaled_matmul_reduce_scatter(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
A_scale: torch.Tensor,
|
||||
B_scale: torch.Tensor,
|
||||
reduce_op: str,
|
||||
scatter_dim: int,
|
||||
orig_scatter_dim: int,
|
||||
scatter_dim_after_maybe_reshape: int,
|
||||
group_name: str,
|
||||
output_shape: list[int],
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
result_scale: Optional[torch.Tensor] = None,
|
||||
out_dtype: Optional[torch.dtype] = None,
|
||||
use_fast_accum: bool = False,
|
||||
) -> torch.Tensor:
|
||||
res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name)
|
||||
res = funcol.wait_tensor(res)
|
||||
return res
|
||||
|
||||
|
||||
@torch.library.impl(lib, "fused_matmul_reduce_scatter", "CUDA")
|
||||
def _fused_matmul_reduce_scatter(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
reduce_op: str,
|
||||
scatter_dim: int,
|
||||
group_name: str,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform the following logic with micro-pipelined computation and
|
||||
communication:
|
||||
|
||||
reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name)
|
||||
|
||||
Optimal stride order for A - if A.movedim(scatter_dim, 0) is contiguous, no
|
||||
extra copy is required for input layout transformation. Otherwise A needs
|
||||
to be copied once.
|
||||
"""
|
||||
if _is_test_mode:
|
||||
return _fused_matmul_reduce_scatter_fallback(
|
||||
A, B, reduce_op, scatter_dim, group_name
|
||||
return _fused_scaled_matmul_reduce_scatter_fallback(
|
||||
A,
|
||||
B,
|
||||
A_scale,
|
||||
B_scale,
|
||||
reduce_op,
|
||||
orig_scatter_dim,
|
||||
scatter_dim_after_maybe_reshape,
|
||||
group_name,
|
||||
output_shape,
|
||||
bias,
|
||||
result_scale,
|
||||
out_dtype,
|
||||
use_fast_accum,
|
||||
)
|
||||
|
||||
with torch.profiler.record_function("fused_matmul_reduce_scatter"):
|
||||
return _fused_matmul_reduce_scatter_impl(
|
||||
mm_out_op=torch.ops.aten.mm.out,
|
||||
with torch.profiler.record_function("fused_scaled_matmul_reduce_scatter"):
|
||||
return _fused_scaled_matmul_reduce_scatter_impl(
|
||||
mm_out_op=torch.ops.aten._scaled_mm.out,
|
||||
A=A,
|
||||
B=B,
|
||||
A_scale=None,
|
||||
kwargs={},
|
||||
out_dtype=A.dtype,
|
||||
A_scale=A_scale,
|
||||
kwargs={
|
||||
"scale_b": B_scale,
|
||||
"bias": bias,
|
||||
"scale_result": result_scale,
|
||||
"out_dtype": out_dtype,
|
||||
"use_fast_accum": use_fast_accum,
|
||||
},
|
||||
out_dtype=out_dtype,
|
||||
reduce_op=reduce_op,
|
||||
scatter_dim=scatter_dim,
|
||||
orig_scatter_dim=orig_scatter_dim,
|
||||
scatter_dim_after_maybe_reshape=scatter_dim_after_maybe_reshape,
|
||||
group_name=group_name,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1138,8 +1170,10 @@ def _fused_scaled_matmul_reduce_scatter_fallback(
|
|||
A_scale: torch.Tensor,
|
||||
B_scale: torch.Tensor,
|
||||
reduce_op: str,
|
||||
scatter_dim: int,
|
||||
orig_scatter_dim: int,
|
||||
scatter_dim_after_maybe_reshape: int,
|
||||
group_name: str,
|
||||
output_shape: list[int],
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
result_scale: Optional[torch.Tensor] = None,
|
||||
out_dtype: Optional[torch.dtype] = None,
|
||||
|
|
@ -1169,63 +1203,133 @@ def _fused_scaled_matmul_reduce_scatter_fallback(
|
|||
out_dtype,
|
||||
use_fast_accum,
|
||||
)
|
||||
C = C.view(*A.shape[:-1], B.shape[1])
|
||||
C = C.view(*output_shape[:-1], B.shape[1])
|
||||
res = funcol.reduce_scatter_tensor(
|
||||
C,
|
||||
reduce_op,
|
||||
scatter_dim,
|
||||
orig_scatter_dim, # need original scatter dim for 3D+ output tensor here
|
||||
group_name,
|
||||
)
|
||||
res = funcol.wait_tensor(res)
|
||||
return res
|
||||
|
||||
|
||||
@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "CUDA")
|
||||
def _fused_scaled_matmul_reduce_scatter(
|
||||
def _fused_scaled_matmul_reduce_scatter_impl(
|
||||
mm_out_op: torch._ops.OpOverload,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
A_scale: torch.Tensor,
|
||||
B_scale: torch.Tensor,
|
||||
kwargs: dict[str, Any],
|
||||
out_dtype: Optional[torch.dtype],
|
||||
reduce_op: str,
|
||||
scatter_dim: int,
|
||||
orig_scatter_dim: int,
|
||||
scatter_dim_after_maybe_reshape: int,
|
||||
group_name: str,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
result_scale: Optional[torch.Tensor] = None,
|
||||
out_dtype: Optional[torch.dtype] = None,
|
||||
use_fast_accum: bool = False,
|
||||
output_shape: list[int],
|
||||
) -> torch.Tensor:
|
||||
if _is_test_mode:
|
||||
return _fused_scaled_matmul_reduce_scatter_fallback(
|
||||
A,
|
||||
B,
|
||||
A_scale,
|
||||
B_scale,
|
||||
reduce_op,
|
||||
scatter_dim,
|
||||
group_name,
|
||||
bias,
|
||||
result_scale,
|
||||
out_dtype,
|
||||
use_fast_accum,
|
||||
)
|
||||
with torch.profiler.record_function("fused_matmul_reduce_scatter"):
|
||||
return _fused_matmul_reduce_scatter_impl(
|
||||
mm_out_op=torch.ops.aten._scaled_mm.out,
|
||||
A=A,
|
||||
B=B,
|
||||
A_scale=A_scale,
|
||||
kwargs={
|
||||
"scale_b": B_scale,
|
||||
"bias": bias,
|
||||
"scale_result": result_scale,
|
||||
"out_dtype": out_dtype,
|
||||
"use_fast_accum": use_fast_accum,
|
||||
},
|
||||
out_dtype=out_dtype,
|
||||
reduce_op=reduce_op,
|
||||
scatter_dim=scatter_dim,
|
||||
group_name=group_name,
|
||||
if A.dim() < 2:
|
||||
raise ValueError("A_shard must be a matrix")
|
||||
if (
|
||||
scatter_dim_after_maybe_reshape < 0
|
||||
or scatter_dim_after_maybe_reshape >= A.dim()
|
||||
):
|
||||
raise ValueError("Invalid scatter dim for 2D tensor input to scaled_mm")
|
||||
if orig_scatter_dim < 0 or orig_scatter_dim >= len(output_shape):
|
||||
raise ValueError("Invalid scatter dim for 3D+ output tensor")
|
||||
if B.dim() != 2:
|
||||
raise ValueError("B must be a matrix")
|
||||
if reduce_op == "sum":
|
||||
reduce_fn = partial(torch.sum, dim=0)
|
||||
elif reduce_op == "avg":
|
||||
reduce_fn = partial(torch.mean, dim=0)
|
||||
else:
|
||||
raise ValueError("reduce_op must be sum or avg")
|
||||
|
||||
group = c10d._resolve_process_group(group_name)
|
||||
|
||||
# Move scatter to first dim, then shard the tensor along the first dim, so the chunk producer
|
||||
# can perform matmuls along the first dim.
|
||||
A_with_scatter_dim_0 = A.movedim(scatter_dim_after_maybe_reshape, 0)
|
||||
|
||||
# To handle case where A is 3D+, reshape to 2D to prepare for mm which requires 2D inputs.
|
||||
A_with_scatter_dim_0 = A.flatten(0, -2)
|
||||
|
||||
# Parition A along the first dim to prepare for sharding across TP process group.
|
||||
A_shards = A_with_scatter_dim_0.chunk(group.size())
|
||||
|
||||
# Now that 'A' is sharded along the first dim, we need to update its scale(s) accordingly.
|
||||
# How we do this depends on if we are using tensorwise scaling, rowwise scaling, or no scaling.
|
||||
tensorwise_scaling = A_scale is not None and A_scale.numel() == 1
|
||||
rowwise_scaling = A_scale is not None and A_scale.numel() > 1
|
||||
|
||||
# For tensorwise scaling, the scale should be replicated so each shard has a copy.
|
||||
if tensorwise_scaling:
|
||||
A_scale_shards = [A_scale] * group.size()
|
||||
|
||||
# For rowwise scaling, we need to move the scatter dim to the first dim to match the
|
||||
# dim swap of the 'A' tensor. Then we can shard the scales along the first dim, just like
|
||||
# the 'A' tensor.
|
||||
elif rowwise_scaling:
|
||||
if A_scale.shape[:-1] != A.shape[:-1]:
|
||||
raise ValueError(
|
||||
"For row-wise scaling, the leading dims of A_scale "
|
||||
"must match the leading dims of A "
|
||||
f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
|
||||
)
|
||||
A_scale = (
|
||||
A_scale.movedim(scatter_dim_after_maybe_reshape, 0)
|
||||
.contiguous()
|
||||
.flatten(0, -2)
|
||||
)
|
||||
A_scale_shards = list(A_scale.chunk(group.size()))
|
||||
else:
|
||||
raise ValueError("A_scale cannot be none for scaled_mm")
|
||||
|
||||
# Computing block-wise matmul along the first dim of A
|
||||
def chunk_producer(rank: int, out: torch.Tensor) -> None:
|
||||
mm_out_op(A_shards[rank], B, scale_a=A_scale_shards[rank], **kwargs, out=out)
|
||||
|
||||
# Stacked partials will be the 2D outputs of the the pipelined scaled mm, and will
|
||||
# have the shape (A_with_scatter_dim_0_tensor.shape[0], B.shape[1]) to align with the formula:
|
||||
# (a*b,c) @ (c,d) = (a*b,d)
|
||||
stacked_partials = A_with_scatter_dim_0.new_empty(
|
||||
A_with_scatter_dim_0.shape[0], B.shape[1], dtype=out_dtype or A.dtype
|
||||
)
|
||||
|
||||
# Execute the pipelined mm/scaled_mm.
|
||||
_pipelined_produce_and_all2all(
|
||||
chunk_producer,
|
||||
stacked_partials,
|
||||
group_name,
|
||||
)
|
||||
|
||||
# We now need to transform the *unreduced* stacked 2D partial mm outputs to an *unreduced* 3D+ output,
|
||||
# then reduce-scatter. To do this, we first need to determine the shape of the unreduced 3D+ output,
|
||||
# to reshape our stacked partials so we can apply the reduce-scatter.
|
||||
#
|
||||
# The *unreduced* 3D+ tensor will have dim 0 = `group_size`, as we have `group_size` instances of
|
||||
# stacked partial outputs. The next dims will be A's leading dims (sharded along the original scatter dim),
|
||||
# as it was the left operand of the mm op. We can use -1 as the final dim of the view to populate the rest.
|
||||
stacked_partials_3D_leading_dims = [group.size()] + list(
|
||||
A_with_scatter_dim_0.shape[:-1]
|
||||
)
|
||||
stacked_partials_3D_leading_dims[orig_scatter_dim] //= group.size()
|
||||
|
||||
# Ensures that the transpose and reduction produce contiguous result
|
||||
# in a single reduction kernel.
|
||||
reduced_out = reduce_fn(
|
||||
# View 2D stacked partials as 3D+ tensor of shape (`group_size`, ...)
|
||||
stacked_partials.view(*stacked_partials_3D_leading_dims, -1)
|
||||
# Swap back the scatter dim (which we moved to 0, and now is `group_size`)
|
||||
.movedim(0, orig_scatter_dim),
|
||||
dim=orig_scatter_dim, # Reduce along the origal scatter dim (`group_size`)
|
||||
)
|
||||
|
||||
# Final 3D+ output shape must be scattered along original scatter dim as well.
|
||||
final_out_shape = [*output_shape[:-1], B.shape[-1]]
|
||||
final_out_shape[orig_scatter_dim] //= group.size()
|
||||
out = reduced_out.view(*final_out_shape)
|
||||
return out
|
||||
|
||||
|
||||
def restride_A_for_fused_matmul_reduce_scatter(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user