[async TP] insert reshape node to handle "reshape -> scaled mm -> reshape" pattern in async TP with rowwise scales (#148001)

Fixes https://github.com/pytorch/torchtitan/issues/864

## Summary
While testing float8 training with rowwise scaling + async TP in torchtitan, a [bug](https://github.com/pytorch/torchtitan/issues/864) was discovered: the scaling factor dims did not match the dims of the tensor the scales were to be applied to.

My [root cause analysis](https://github.com/pytorch/torchtitan/issues/864#issuecomment-2672465060) determined that when the async TP graph manipulation constructs the `fused_scaled_matmul_reduce_scatter` op, it does not yet handle the "reshape -> scaled mm -> reshape" pattern used in torchao [here](ed361ff5c7/torchao/float8/float8_linear.py (L122-L124)), specifically when row-wise scales are being used.
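
For orientation, here is a minimal sketch of that pattern (a simplification of the torchao float8 linear path linked above, not the actual implementation; the function name and scale handling are illustrative): `torch._scaled_mm` only takes 2D operands, so a 3D+ `A` tensor is flattened along its leading dims, its rowwise scale is flattened with it, and the output is reshaped back.

```
import torch

# Minimal sketch of the "reshape -> scaled mm -> reshape" pattern with rowwise scales.
# Assumption: simplified stand-in for torchao's float8_linear path, not the real code.
def reshape_scaled_mm_reshape(A_fp8, B_fp8_t, A_scale, B_scale, out_dtype=torch.bfloat16):
    orig_shape = A_fp8.shape
    # flatten leading dims: (B, M, K) -> (B*M, K); the rowwise scale follows: (B, M, 1) -> (B*M, 1)
    A_2d = A_fp8.reshape(-1, orig_shape[-1])
    A_scale_2d = torch.reciprocal(A_scale.reshape(-1, A_scale.shape[-1]))  # the reciprocal the compiler traces
    C = torch._scaled_mm(A_2d, B_fp8_t, A_scale_2d, B_scale, out_dtype=out_dtype)
    # restore the leading dims on the output: (B*M, N) -> (B, M, N)
    return C.view(*orig_shape[:-1], C.shape[-1])
```

The new unit test added in this PR exercises essentially this flow, followed by a `reduce_scatter_tensor` on the output.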

## TL;DR of root cause
- When a Float8Tensor is reshaped, the scale is reshaped along with it so the dimensions are aligned.
- In the graph manipulation logic of the micro-pipeline TP post-grad pass, the scaled_mm `A tensor` node references the tensor from _before_ the reshape op, while the `A_scale` node references the scale from _after_ the reshape op.

## Example
- Concrete example:
    - `A tensor` is a Float8Tensor with shape (1,8192,2048) and a scale of shape (1,8192,1) when a matmul op is called in torchao [here](8706d3f3b0/torchao/float8/float8_linear.py (L70)). Torchao does a reshape -> scaled mm -> reshape [here](ed361ff5c7/torchao/float8/float8_linear.py (L122)), and when a Float8Tensor is reshaped, its scale is reshaped along with it [here](8706d3f3b0/torchao/float8/float8_ops.py (L152)). So the first reshape takes the `A tensor` from (1,8192,2048) => (8192,2048) and the scale from (1,8192,1) => (8192,1).
    - During the post-grad pass in async TP:
        - `A_node` has shape (1,8192,2048) (the tensor from before this [reshape](ed361ff5c7/torchao/float8/float8_linear.py (L122)))
        - `A_scale` has shape (8192,1) (the reshape op above reshaped the scale from (1,8192,1) => (8192,1)); see the sketch below
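
The same shape bookkeeping as a small sketch with plain tensors (not actual Float8Tensor internals; the shapes are the ones from the example above):

```
import torch

A = torch.randn(1, 8192, 2048)            # "A tensor" before torchao's reshape
A_scale = torch.rand(1, 8192, 1) + 1e-4   # rowwise scale: one value per row of A

# torchao flattens the leading dims of the tensor and its scale together
A_2d = A.reshape(-1, A.shape[-1])                    # (8192, 2048)
A_scale_2d = A_scale.reshape(-1, A_scale.shape[-1])  # (8192, 1)

# what the post-grad pass was pairing up: the pre-reshape A node with the
# post-reshape scale node, i.e. a 3D tensor with a 2D scale
print(A.shape, A_scale_2d.shape)     # torch.Size([1, 8192, 2048]) torch.Size([8192, 1])

# the dims only line up if both sides come from the same side of the reshape
print(A_2d.shape, A_scale_2d.shape)  # torch.Size([8192, 2048]) torch.Size([8192, 1])
```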

## Solution

**Note:** the compiler inserts a `reciprocal` op after the reshape, so we can't simply use the node from before the reshape as the `A_scale_node` without affecting the numerics.

- Short-term solution: if the specific pattern shown below is detected, insert a reshape node after the reciprocal, to reshape the reciprocal output back to the original shape from before the first reshape.
    - The reshape is just a view, so there should be no impact on performance
```
Before:
    reshape (a,b,c) to (a*b,c) -> reciprocal

After:
    reshape (a,b,c) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c)
```
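
As a sanity check on the numerics note above, a small sketch (plain tensors, not the compiler pass itself): reciprocal is elementwise, so reshaping its output back to the pre-reshape shape reproduces exactly the values the original 3D scale would give, and on a contiguous tensor that reshape is just a view.

```
import torch

scale = torch.rand(1, 8192, 1) + 1e-4

# what the traced graph has today: reshape -> reciprocal, leaving a 2D scale
flat_recip = torch.reciprocal(scale.reshape(-1, scale.shape[-1]))  # (8192, 1)

# the node this PR inserts: reshape the reciprocal output back to the pre-reshape shape
fixed = flat_recip.reshape(scale.shape)                            # (1, 8192, 1)

torch.testing.assert_close(fixed, torch.reciprocal(scale))  # identical numerics
```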

- Long-term solution: implement a `torch._scaled_matmul` op which can support 3D+ `A` tensors

## Test plan
- Added a unit test which exercises this new path
- Manually tested with torchtitan using float8 rowwise scaling + async TP

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148001
Approved by: https://github.com/yifuwang
Commit: b8efebe57d (parent: fd16311e7f)
Author: Daniel Vega-Myhre, 2025-03-01 06:38:39 +00:00 (committed by PyTorch MergeBot)
2 changed files with 159 additions and 4 deletions


@@ -399,6 +399,69 @@ class MicroPipelineTPTest(TestCase):
        self.assertIn("fused_scaled_matmul_reduce_scatter", code)
        self.assertNotIn("reduce_scatter_tensor", code)

    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @parametrize("scatter_dim", [2])
    @fresh_inductor_cache()
    def test_fuse_scaled_matmul_reduce_scatter_rowwise_scales_reshape_mm_reshape(
        self, scatter_dim
    ):
        group = dist.group.WORLD

        def reshape_mm_reshape(
            A: torch.Tensor,
            B: torch.Tensor,
            A_scale: torch.Tensor,
            B_scale: torch.Tensor,
            out_dtype: torch.dtype,
        ) -> torch.Tensor:
            """
            Performs a scaled_mm followed by a reduce scatter,
            following the reshape -> scaled_mm -> reshape pattern.
            """
            orig_shape = A.shape

            # reshape tensor and scale together
            A = A.reshape(-1, orig_shape[-1])
            A_scale = A_scale.reshape(-1, A_scale.shape[-1])
            A_scale = torch.reciprocal(A_scale)

            C = torch._scaled_mm(A, B, A_scale, B_scale, out_dtype=out_dtype)

            # reshape output to have same leading dims as original `A` tensor
            C = C.view(*orig_shape[:-1], C.shape[-1])
            return reduce_scatter_tensor(C, "sum", scatter_dim, group)

        A = torch.rand(1, 16, 32, device="cuda").to(torch.float8_e4m3fn)
        B = torch.rand(64, 32, device="cuda").to(torch.float8_e4m3fn).T

        # A_scale = rowwise scales
        A_scale = torch.full((1, 16, 1), 0.1, device="cuda")
        # B_scale = rowwise scales transposed for A @ B^T
        B_scale = torch.full((1, 64), 0.1, device="cuda")

        gm = _make_post_grad_fx(
            reshape_mm_reshape, A, B, A_scale, B_scale, torch.bfloat16
        )
        with _test_mode():
            micro_pipeline_tp_pass(gm.graph)
        self.assertIn("fused_scaled_matmul_reduce_scatter", str(gm.graph))
        self.assertNotIn("reduce_scatter_tensor", str(gm.graph))

        if torch.cuda.get_device_capability() < (8, 9):
            return

        with _test_mode():
            compiled = torch.compile(reshape_mm_reshape)
            code = run_and_get_triton_code(
                compiled, A, B, A_scale, B_scale, torch.bfloat16
            )
        self.assertIn("fused_scaled_matmul_reduce_scatter", code)
        self.assertNotIn("reduce_scatter_tensor", code)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @parametrize("shard_dim", [0, 1])
    @fresh_inductor_cache()


@@ -386,12 +386,104 @@ class _ScaledMatmul(_Matmul):
                return default
            return node.args[idx]

        def insert_reshape_op(node: torch.fx.Node):
            """
            Given a reciprocal node with a parent reshape node,
            insert a reshape node after the reciprocal node which reshapes
            the reciprocal output back to the original shape before the first reshape.

            Before:
                reshape (a,b,c) to (a*b,c) -> reciprocal

            After:
                reshape (a,b,c) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c)

            Returns the new reshape node.
            """
            # ensure the given node matches the pattern described in the docstring
            assert (
                node.target == aten.reciprocal.default
            ), "Node must be a aten.reciprocal.default op"
            assert len(node.all_input_nodes) == 1, "Node must have exactly one parent"

            parent_node = node.all_input_nodes[0]
            assert (
                parent_node.target == aten.reshape.default
            ), "Parent node must be a aten.reshape.default op"
            assert (
                len(parent_node.all_input_nodes) == 1
            ), "Parent node must have exactly one input node"

            parent_input_node = parent_node.all_input_nodes[0]
            parent_input_shape = list(_get_tensor(parent_input_node).shape)

            # insert reshape back to the shape from before the parent reshape op
            graph = node.graph
            with graph.inserting_after(node):
                reshape_node = graph.call_function(
                    aten.reshape.default, (node, parent_input_shape)
                )

            # ensure all users of the original node (except the new reshape node)
            # now use the reshaped node instead
            node_users = list(node.users)
            for user in node_users:
                if user != reshape_node:
                    user.replace_input_with(node, reshape_node)

            return reshape_node

        is_reshape_mm_reshape_pattern = match[0].target == aten.reshape.default
        mm_node = match[1] if is_reshape_mm_reshape_pattern else match[0]

        # `A_node` is pulled directly from `match` rather than `mm_node` because it needs
        # to handle both of the following cases:
        #
        # Case 1: single node match (mm):
        #   - match[0].args[0] will be the "A tensor" node of scaled_mm
        #   - Has 2D shape
        #
        # Case 2: 3 node match (reshape -> mm -> reshape)
        #   - match[0].args[0] will be the "A tensor" input to the reshape op
        #   - Has 3D+ shape
        A_node = cast(torch.fx.Node, match[0].args[0])
        B_node = cast(torch.fx.Node, mm_node.args[1])
        A_scale_node = cast(torch.fx.Node, mm_node.args[2])
        B_scale_node = cast(torch.fx.Node, mm_node.args[3])

        A_ndim = _get_tensor(A_node).ndim
        A_scale_ndim = _get_tensor(A_scale_node).ndim
        is_reciprocal_with_reshape_parent = (
            A_scale_node.target == aten.reciprocal.default
            and len(A_scale_node.all_input_nodes) == 1
            and A_scale_node.all_input_nodes[0].target == aten.reshape.default
        )
        is_tensorwise_scaling = A_scale_ndim <= 1

        # This is a temporary workaround to handle the reshape -> scaled_mm -> reshape
        # pattern when scales are row-wise, and have been reshaped along with the target
        # tensor. See https://github.com/pytorch/pytorch/pull/148001 for details.
        #
        # If the tensor dim does not match the scale dim, check if the scale node follows
        # the "reshape -> reciprocal" pattern. If so, we can insert a reshape op after
        # the reciprocal, to reshape the reciprocal output back to the original shape
        # from before the first reshape op.
        #
        # TODO: remove this workaround once torch._scaled_matmul exists and can be used
        # to implement more robust long-term support for 3D+ scaled matmuls.
        if (
            is_reshape_mm_reshape_pattern
            and A_ndim != A_scale_ndim
            and not is_tensorwise_scaling
            and is_reciprocal_with_reshape_parent
        ):
            A_scale_node = insert_reshape_op(A_scale_node)

        return _ScaledMatmul(
            nodes=match,
-           A_node=cast(torch.fx.Node, match[0].args[0]),
-           B_node=cast(torch.fx.Node, mm_node.args[1]),
-           A_scale_node=cast(torch.fx.Node, mm_node.args[2]),
-           B_scale_node=cast(torch.fx.Node, mm_node.args[3]),
+           A_node=A_node,
+           B_node=B_node,
+           A_scale_node=A_scale_node,
+           B_scale_node=B_scale_node,
            bias_node=get_arg(mm_node, 4, None),
            result_scale_node=get_arg(mm_node, 5, None),
            out_dtype=get_arg(mm_node, 6, None),