[Inductor][Optimus] Add move view after cat aten pattern (#149178)

Summary:
Add an aten-level pattern that moves the view/reshape out of a split → view → cat chain, replacing the whole chain with a single reshape (plus a permute when the cat dim differs from the split dim), to further reduce the number of kernels.

context: https://docs.google.com/document/d/1G2qFcQu1K7VXbz2uPe0CS2aBirnwtwI_B8lxmlBlAPQ/edit?tab=t.0
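
For intuition, here is a minimal standalone sketch (not part of the diff) of the rewrite this pass performs, using the same shapes as the unit test below: a split → per-slice view → cat chain collapses into a single permute + reshape of the split input.
```
import torch

x = torch.randn(7, 8, 96)

# Before: split -> per-slice view -> cat, as it appears in the matched graph.
parts = torch.split(x, [1] * 7, dim=0)          # seven (1, 8, 96) slices
before = torch.cat([p.view(8, 96) for p in parts], dim=1)

# After: one permute + reshape. The permute is needed because the cat dim (1)
# differs from the split dim (0).
after = x.permute(1, 0, 2).reshape(8, 7 * 96)

assert torch.equal(before, after)
```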

Test Plan:
### how to enable
Add the following patterns to the post-grad fusion options:
```
        post_grad_fusion_options={
            "normalization_aten_pass": {},
            "move_view_after_cat_aten_pass": {},
        },
```
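
Put together, enabling the pass looks roughly like this (a sketch; `Model` is a hypothetical module containing the targeted split → view → cat chain):
```
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        parts = torch.split(x, [1] * 7, dim=0)
        views = [p.view(8, 96) for p in parts]
        return torch.cat(views, dim=1)

with torch._inductor.config.patch(
    pre_grad_fusion_options={},
    post_grad_fusion_options={
        "normalization_aten_pass": {},
        "move_view_after_cat_aten_pass": {},
    },
):
    compiled = torch.compile(Model())
    out = compiled(torch.randn(7, 8, 96))
```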

### unit test
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:split_cat_fx_aten_passes -- test_move_view_after_cat_aten
```

Buck UI: https://www.internalfb.com/buck2/3c5451be-c63a-4794-8d6b-103ecac78905
Test UI: https://www.internalfb.com/intern/testinfra/testrun/6192449704507267

### local reproduce

```
buck2 run mode/opt scripts/shuaiyang:test -- --flow_id 691990503 --use_synthetic_data --optimus
```
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/mengluy/2025-03-13-20-59-34/trace.json.gz&bucket=gpu_traces

### E2E

baseline

f691990503

proposal

Differential Revision: D71177004

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149178
Approved by: https://github.com/Yuzhen11
Author: Menglu Yu
Date: 2025-03-20 04:07:25 +00:00
Committed by: PyTorch MergeBot
Parent: 95e71765f2
Commit: fe94d7da1a
2 changed files with 157 additions and 0 deletions


@@ -135,6 +135,46 @@ class TestSplitCatPartial(torch.nn.Module):
        return cat


class TestMoveViewAferCat(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor):
        split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(
            x, [1, 1, 1, 1, 1, 1, 1]
        )
        getitem_71 = split_with_sizes_1[0]
        getitem_72 = split_with_sizes_1[1]
        getitem_73 = split_with_sizes_1[2]
        getitem_74 = split_with_sizes_1[3]
        getitem_75 = split_with_sizes_1[4]
        getitem_76 = split_with_sizes_1[5]
        getitem_77 = split_with_sizes_1[6]
        view_1 = torch.ops.aten.view.default(getitem_71, [8, 96])
        view_2 = torch.ops.aten.view.default(getitem_72, [8, 96])
        view_3 = torch.ops.aten.view.default(getitem_73, [8, 96])
        view_4 = torch.ops.aten.view.default(getitem_74, [8, 96])
        view_5 = torch.ops.aten.view.default(getitem_75, [8, 96])
        view_6 = torch.ops.aten.view.default(getitem_76, [8, 96])
        view_7 = torch.ops.aten.view.default(getitem_77, [8, 96])
        clone = torch.ops.aten.clone.default(view_1)
        cat = torch.ops.aten.cat.default(
            [
                view_1,
                view_2,
                view_3,
                view_4,
                view_5,
                view_6,
                view_7,
            ],
            1,
        )
        cat_1 = torch.ops.aten.cat.default([clone, cat], 1)
        return torch.cat([clone, cat_1], 1)


class TestSelectCat(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
@@ -260,6 +300,30 @@ class TestSplitCatAten(TestCase):
        self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
        counters.clear()

    @requires_cuda
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "normalization_aten_pass": {},
            "move_view_after_cat_aten_pass": {},
        },
    )
    def test_move_view_after_cat_aten(self):
        counters.clear()
        inputs = [
            torch.randn(7, 8, 96, device=torch.device(device=GPU_TYPE)),
        ]
        module = TestMoveViewAferCat()
        traced = torch.compile(module)
        ref = module(*inputs)
        res = traced(*inputs)
        self.compare_pred(module, traced, inputs)
        self.assertEqual(counters["inductor"]["normalization_aten_pass"], 4)
        self.assertEqual(counters["inductor"]["move_view_after_cat_aten_pass"], 1)
        self.assertEqual(ref, res, rtol=1e-8, atol=1e-8)
        self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
        counters.clear()


if __name__ == "__main__":
    run_tests()


@@ -72,6 +72,7 @@ post_grad_pass_names = [
    "pad_aten_mm_pass",
    "split_cat_aten_pass",
    "select_cat_aten_pass",
    "move_view_after_cat_aten_pass",
]

for pass_name in pre_grad_pass_names:
@@ -2864,3 +2865,95 @@ def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
    remove_split_unbind_children(graph, stack_inputs)  # type: ignore[arg-type]
    remove_split_unbind_children(graph, split_users)  # type: ignore[arg-type]
    counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1


# matches a list of reshape nodes, each consuming a getitem of the same
# split_with_sizes node, feeding the inputs of a cat; partial=True lets the
# list match even if not every cat input fits the pattern
view_getitem_split_aten = ListOf(
    CallFunction(
        [torch.ops.aten.reshape.default],
        CallFunction(
            operator.getitem,
            CallFunctionVarArgs(
                torch.ops.aten.split_with_sizes.default, users=MULTIPLE
            ),
            Ignored(),
            _users=MULTIPLE,
        ),
        Arg(),
        _users=MULTIPLE,
    ),
    partial=True,
)


@register_graph_pattern(
    CallFunction(
        torch.ops.aten.cat.default,
        view_getitem_split_aten,
        dim=Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=construct_pattern_matcher_pass("move_view_after_cat_aten_pass"),
)
def move_view_after_cat(match: Match, *args, **kwargs):
    split_node = next(
        node
        for node in match.nodes
        if node.target == torch.ops.aten.split_with_sizes.default
    )
    split_input, split_section, split_dim = _get_split_args_default(split_node)
    split_users = list(split_node.users.keys())
    getitem_indices = [
        getitem.args[1] for getitem in split_users if getitem.target == operator.getitem
    ]
    if not is_sorted_and_consecutive(getitem_indices):  # type: ignore[arg-type]
        return
    cat_nodes = [
        node for node in match.nodes if node.target == torch.ops.aten.cat.default
    ]
    graph = match.graph
    for cat_node in cat_nodes:
        if not is_node_meta_valid(cat_node):
            log.debug("example value absent for node: %s", cat_node)
            continue
        cat_dim = _get_dim(cat_node)
        cat_inputs = get_arg_value(cat_node, 0, "tensors")  # type: ignore[union-attr]
        # only handle the special case where the cat consumes every split output
        if len(cat_inputs) != len(split_section):
            continue
        # check that the cat inputs are all view/reshape nodes
        if not all(
            view_node.target == torch.ops.aten.reshape.default
            for view_node in cat_inputs
        ):
            continue
        # check that the view nodes all come from getitem nodes
        if not all(
            view_node.args[0].target == operator.getitem for view_node in cat_inputs
        ):
            continue
        view_indices = [view.args[0].args[1] for view in cat_inputs]
        if not is_sorted_and_consecutive(view_indices):  # type: ignore[arg-type]
            continue
        if cat_dim != split_dim:
            # construct a permute node swapping the split and cat dims; the
            # permute acts on the split input, whose rank is one more than the
            # cat output's because each view collapses the size-1 split dim
            permute_list = list(range(len(cat_node.meta["val"].shape) + 1))
            permute_list[split_dim], permute_list[cat_dim] = (
                permute_list[cat_dim],
                permute_list[split_dim],
            )
            # insert the permute before the cat so the reshape below can use it
            with graph.inserting_before(cat_node):
                permute_node = graph.call_function(
                    torch.ops.aten.permute.default,
                    args=(split_input, permute_list),
                )
        else:
            permute_node = split_input
        # replace the whole split -> view -> cat chain with a single reshape
        with graph.inserting_before(cat_node):
            view_node = graph.call_function(
                torch.ops.aten.reshape.default,
                args=(permute_node, list(cat_node.meta["val"].shape)),
            )
        cat_node.replace_all_uses_with(view_node)
        view_node.meta.update(cat_node.meta)
        graph.erase_node(cat_node)
        counters["inductor"]["move_view_after_cat_aten_pass"] += 1