mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[async TP] insert reshape node to handle "reshape -> scaled mm -> reshape pattern" in async TP with rowwise scales (#148001)
Fixes https://github.com/pytorch/torchtitan/issues/864 ## Summary While testing torchtitan with float8 training with rowwise scaling + async TP, a [bug](https://github.com/pytorch/torchtitan/issues/864) was discovered. The symptom was the scaling factor dims did not match the dims of the tensor the scales were to be applied to. My [root cause analysis](https://github.com/pytorch/torchtitan/issues/864#issuecomment-2672465060) determined the reason is that when async TP graph manipulation constructs the `fused_scaled_matmul_reduce_scatter` op, it does not yet handle the "reshape -> scaled mm -> reshape" pattern used in torchao [here](ed361ff5c7/torchao/float8/float8_linear.py (L122-L124)) - specifically when row-wise scales are being used. ## TL;DR of root cause - When a Float8Tensor is reshaped, the scale is reshaped along with it so the dimensions are aligned. - In the graph manipulation logic of the micropipeline TP post grad pass, the scaled_mm `A tensor` node is referencing the tensor _before_ to the reshape op, but referencing the `A_scale` node _after_ the reshape op. ## Example - Concrete example: - `A tensor` is a Float8Tensor with shape (1,8192,2048) and scale of shape (1,8192,1) when a matmul op is called in torchao [here](8706d3f3b0/torchao/float8/float8_linear.py (L70)). Torchao does a reshape -> scaled mm -> reshape [here](ed361ff5c7/torchao/float8/float8_linear.py (L122)). When a Float8Tensor is reshaped, its scale is reshaped along with it [here](8706d3f3b0/torchao/float8/float8_ops.py (L152)). So the first reshape makes the "A tensor" (1,8192,2048) => (8192,2048) and the scale (1,8192,1) => (8192,1). - During post grad pass in async TP: - `A_node` has shape (1,8192,2048) (tensor from before this [reshape](ed361ff5c7/torchao/float8/float8_linear.py (L122))) - `A_scale` has shape (8192,1) (due to reshape op above, which caused the scale to be reshaped from (1,8192,1) => (8192,1)). ## Solution **Note:** the compiler inserts a `reciprocal` op after the reshape, so we can't simply use the node before the reshape as the `A_scale_node`, otherwise it will affect the numerics. - Short-term solution: if the specific pattern showne below is detected, insert a reshape node after the reciprocal, to reshape the reciprocal output back to the originals shape before the reshape. - reshape is just a view, so there should be no impact on performance ``` Before: reshape (a,bc,) to (a*b,c) -> reciprocal After: reshape (a,bc,) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c) ``` - Long-term solution: implement a `torch._scaled_matmul` which can support 3D+ `A tensor` ## Test plan - Added unit test which exercises this new path - Manually tested with torchtitan with float8 rowwise + async TP Pull Request resolved: https://github.com/pytorch/pytorch/pull/148001 Approved by: https://github.com/yifuwang
This commit is contained in:
parent
fd16311e7f
commit
b8efebe57d
|
|
@ -399,6 +399,69 @@ class MicroPipelineTPTest(TestCase):
|
||||||
self.assertIn("fused_scaled_matmul_reduce_scatter", code)
|
self.assertIn("fused_scaled_matmul_reduce_scatter", code)
|
||||||
self.assertNotIn("reduce_scatter_tensor", code)
|
self.assertNotIn("reduce_scatter_tensor", code)
|
||||||
|
|
||||||
|
@skipIfRocm
|
||||||
|
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||||
|
@parametrize("scatter_dim", [2])
|
||||||
|
@fresh_inductor_cache()
|
||||||
|
def test_fuse_scaled_matmul_reduce_scatter_rowwise_scales_reshape_mm_reshape(
|
||||||
|
self, scatter_dim
|
||||||
|
):
|
||||||
|
group = dist.group.WORLD
|
||||||
|
|
||||||
|
def reshape_mm_reshape(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
A_scale: torch.Tensor,
|
||||||
|
B_scale: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Performs a scaled_mm followed by a reduce scatter,
|
||||||
|
following the reshape -> scaled_mm -> reshape pattern.
|
||||||
|
"""
|
||||||
|
orig_shape = A.shape
|
||||||
|
|
||||||
|
# reshape tensor and scale together
|
||||||
|
A = A.reshape(-1, orig_shape[-1])
|
||||||
|
A_scale = A_scale.reshape(-1, A_scale.shape[-1])
|
||||||
|
A_scale = torch.reciprocal(A_scale)
|
||||||
|
|
||||||
|
C = torch._scaled_mm(A, B, A_scale, B_scale, out_dtype=out_dtype)
|
||||||
|
|
||||||
|
# reshape output to have same leading dims as original `A` tensor
|
||||||
|
C = C.view(*orig_shape[:-1], C.shape[-1])
|
||||||
|
return reduce_scatter_tensor(C, "sum", scatter_dim, group)
|
||||||
|
|
||||||
|
A = torch.rand(1, 16, 32, device="cuda").to(torch.float8_e4m3fn)
|
||||||
|
B = torch.rand(64, 32, device="cuda").to(torch.float8_e4m3fn).T
|
||||||
|
|
||||||
|
# A_scale = rowwise scales
|
||||||
|
A_scale = torch.full((1, 16, 1), 0.1, device="cuda")
|
||||||
|
|
||||||
|
# B_scale = rowwise scales transposed for A @ B^T
|
||||||
|
B_scale = torch.full((1, 64), 0.1, device="cuda")
|
||||||
|
|
||||||
|
gm = _make_post_grad_fx(
|
||||||
|
reshape_mm_reshape, A, B, A_scale, B_scale, torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
with _test_mode():
|
||||||
|
micro_pipeline_tp_pass(gm.graph)
|
||||||
|
|
||||||
|
self.assertIn("fused_scaled_matmul_reduce_scatter", str(gm.graph))
|
||||||
|
self.assertNotIn("reduce_scatter_tensor", str(gm.graph))
|
||||||
|
|
||||||
|
if torch.cuda.get_device_capability() < (8, 9):
|
||||||
|
return
|
||||||
|
|
||||||
|
with _test_mode():
|
||||||
|
compiled = torch.compile(reshape_mm_reshape)
|
||||||
|
code = run_and_get_triton_code(
|
||||||
|
compiled, A, B, A_scale, B_scale, torch.bfloat16
|
||||||
|
)
|
||||||
|
self.assertIn("fused_scaled_matmul_reduce_scatter", code)
|
||||||
|
self.assertNotIn("reduce_scatter_tensor", code)
|
||||||
|
|
||||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||||
@parametrize("shard_dim", [0, 1])
|
@parametrize("shard_dim", [0, 1])
|
||||||
@fresh_inductor_cache()
|
@fresh_inductor_cache()
|
||||||
|
|
|
||||||
|
|
@ -386,12 +386,104 @@ class _ScaledMatmul(_Matmul):
|
||||||
return default
|
return default
|
||||||
return node.args[idx]
|
return node.args[idx]
|
||||||
|
|
||||||
|
def insert_reshape_op(node: torch.fx.Node):
|
||||||
|
"""
|
||||||
|
Given a reciprocal node with a parent reshape node,
|
||||||
|
insert a reshape node after the reciprocal node which reshapes
|
||||||
|
the reciprocal output back to the original shape before the first reshape.
|
||||||
|
|
||||||
|
Before:
|
||||||
|
reshape (a,bc,) to (a*b,c) -> reciprocal
|
||||||
|
|
||||||
|
After:
|
||||||
|
reshape (a,bc,) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c)
|
||||||
|
|
||||||
|
Returns the new reshape node.
|
||||||
|
"""
|
||||||
|
# ensure the given node matches the pattern described in the docstring
|
||||||
|
assert (
|
||||||
|
node.target == aten.reciprocal.default
|
||||||
|
), "Node must be a aten.reciprocal.default op"
|
||||||
|
assert len(node.all_input_nodes) == 1, "Node must have exactly one parent"
|
||||||
|
|
||||||
|
parent_node = node.all_input_nodes[0]
|
||||||
|
assert (
|
||||||
|
parent_node.target == aten.reshape.default
|
||||||
|
), "Parent node must be a aten.reshape.default op"
|
||||||
|
assert (
|
||||||
|
len(parent_node.all_input_nodes) == 1
|
||||||
|
), "Parent node must have exactly one input node"
|
||||||
|
|
||||||
|
parent_input_node = parent_node.all_input_nodes[0]
|
||||||
|
parent_input_shape = list(_get_tensor(parent_input_node).shape)
|
||||||
|
|
||||||
|
# insert reshape back to shape from before the parent reshape op
|
||||||
|
graph = node.graph
|
||||||
|
with graph.inserting_after(node):
|
||||||
|
reshape_node = graph.call_function(
|
||||||
|
aten.reshape.default, (node, parent_input_shape)
|
||||||
|
)
|
||||||
|
|
||||||
|
# ensure all users of original node (except the reshape node) now use the reshaped node instead
|
||||||
|
node_users = list(node.users)
|
||||||
|
for user in node_users:
|
||||||
|
if user != reshape_node:
|
||||||
|
user.replace_input_with(node, reshape_node)
|
||||||
|
|
||||||
|
return reshape_node
|
||||||
|
|
||||||
|
is_reshape_mm_reshape_pattern = match[0].target == aten.reshape.default
|
||||||
|
mm_node = match[1] if is_reshape_mm_reshape_pattern else match[0]
|
||||||
|
|
||||||
|
# `A_node` is pulled directly from match rather than `mm_node` because it needs to handle
|
||||||
|
# both of the following cases:
|
||||||
|
#
|
||||||
|
# Case 1: single node match (mm):
|
||||||
|
# - match[0].args[0] will be the "A tensor" node of scaled_mm
|
||||||
|
# - Has 2D shape
|
||||||
|
#
|
||||||
|
# Case 2: 3 node match (reshape -> mm -> reshape)
|
||||||
|
# - match[0].args[0] will be the "A tensor" input to the reshape op
|
||||||
|
# - Has 3D+ shape
|
||||||
|
A_node = cast(torch.fx.Node, match[0].args[0])
|
||||||
|
B_node = cast(torch.fx.Node, mm_node.args[1])
|
||||||
|
A_scale_node = cast(torch.fx.Node, mm_node.args[2])
|
||||||
|
B_scale_node = cast(torch.fx.Node, mm_node.args[3])
|
||||||
|
|
||||||
|
A_ndim = _get_tensor(A_node).ndim
|
||||||
|
A_scale_ndim = _get_tensor(A_scale_node).ndim
|
||||||
|
is_reciprocal_with_reshape_parent = (
|
||||||
|
A_scale_node.target == aten.reciprocal.default
|
||||||
|
and len(A_scale_node.all_input_nodes) == 1
|
||||||
|
and A_scale_node.all_input_nodes[0].target == aten.reshape.default
|
||||||
|
)
|
||||||
|
is_tensorwise_scaling = A_scale_ndim <= 1
|
||||||
|
|
||||||
|
# This is a temporary workaround to handle the reshape -> scaled_mm -> reshape
|
||||||
|
# pattern when scales are row-wise, and have been reshaped along with the target
|
||||||
|
# tensor. See https://github.com/pytorch/pytorch/pull/148001 for details.
|
||||||
|
#
|
||||||
|
# If tensor dim does not match scale dim, check if the scale node follows
|
||||||
|
# the "reshape -> reciprocal" pattern. If so, we can insert a reshape op after
|
||||||
|
# the reciprocal, to reshape the reciprocal back to the original shape before
|
||||||
|
# the first reshape op.
|
||||||
|
#
|
||||||
|
# TODO: remove this workaround once torch._scaled_matmul exists and can be used
|
||||||
|
# to implement a more robust long-term support for 3D+ scaled matmuls.
|
||||||
|
if (
|
||||||
|
is_reshape_mm_reshape_pattern
|
||||||
|
and A_ndim != A_scale_ndim
|
||||||
|
and not is_tensorwise_scaling
|
||||||
|
and is_reciprocal_with_reshape_parent
|
||||||
|
):
|
||||||
|
A_scale_node = insert_reshape_op(A_scale_node)
|
||||||
|
|
||||||
return _ScaledMatmul(
|
return _ScaledMatmul(
|
||||||
nodes=match,
|
nodes=match,
|
||||||
A_node=cast(torch.fx.Node, match[0].args[0]),
|
A_node=A_node,
|
||||||
B_node=cast(torch.fx.Node, mm_node.args[1]),
|
B_node=B_node,
|
||||||
A_scale_node=cast(torch.fx.Node, mm_node.args[2]),
|
A_scale_node=A_scale_node,
|
||||||
B_scale_node=cast(torch.fx.Node, mm_node.args[3]),
|
B_scale_node=B_scale_node,
|
||||||
bias_node=get_arg(mm_node, 4, None),
|
bias_node=get_arg(mm_node, 4, None),
|
||||||
result_scale_node=get_arg(mm_node, 5, None),
|
result_scale_node=get_arg(mm_node, 5, None),
|
||||||
out_dtype=get_arg(mm_node, 6, None),
|
out_dtype=get_arg(mm_node, 6, None),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user