mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor][Optimus] Add move view after cat aten pattern (#149178)
Summary: Add aten pattern to move the view/reshape out of split cat, further reduce the number of kernels. context: https://docs.google.com/document/d/1G2qFcQu1K7VXbz2uPe0CS2aBirnwtwI_B8lxmlBlAPQ/edit?tab=t.0 Test Plan: ### how to enable Add the following patterns to the post grad ``` post_grad_fusion_options={ "normalization_aten_pass": {}, "move_view_after_cat_aten_pass": {}, }, ``` ### unit test ``` buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:split_cat_fx_aten_passes -- test_move_view_after_cat_aten ``` Buck UI: https://www.internalfb.com/buck2/3c5451be-c63a-4794-8d6b-103ecac78905 Test UI: https://www.internalfb.com/intern/testinfra/testrun/6192449704507267 ### local reproduce ``` buck2 run mode/opt scripts/shuaiyang:test -- --flow_id 691990503 --use_synthetic_data --optimus ``` https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/mengluy/2025-03-13-20-59-34/trace.json.gz&bucket=gpu_traces ### E2E baseline f691990503 proposal Differential Revision: D71177004 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149178 Approved by: https://github.com/Yuzhen11
This commit is contained in:
parent
95e71765f2
commit
fe94d7da1a
|
|
@ -135,6 +135,46 @@ class TestSplitCatPartial(torch.nn.Module):
|
|||
return cat
|
||||
|
||||
|
||||
class TestMoveViewAferCat(torch.nn.Module):
    """Fixture module for the ``move_view_after_cat_aten_pass`` test.

    NOTE(review): the exact aten op chain below
    (split_with_sizes -> getitem -> view -> cat) is what the post-grad
    pattern matcher keys on; do not restructure or the pass may not fire.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor):
        # Split along dim 0 into seven size-1 chunks; assumes x.shape[0] == 7
        # and each chunk has 8 * 96 elements (see the views below) -- the
        # test feeds x of shape (7, 8, 96).
        split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(
            x, [1, 1, 1, 1, 1, 1, 1]
        )
        getitem_71 = split_with_sizes_1[0]
        getitem_72 = split_with_sizes_1[1]
        getitem_73 = split_with_sizes_1[2]
        getitem_74 = split_with_sizes_1[3]
        getitem_75 = split_with_sizes_1[4]
        getitem_76 = split_with_sizes_1[5]
        getitem_77 = split_with_sizes_1[6]
        # Reshape every (1, 8, 96) chunk to (8, 96); these are the views the
        # pass under test should move past the cat.
        view_1 = torch.ops.aten.view.default(getitem_71, [8, 96])
        view_2 = torch.ops.aten.view.default(getitem_72, [8, 96])
        view_3 = torch.ops.aten.view.default(getitem_73, [8, 96])
        view_4 = torch.ops.aten.view.default(getitem_74, [8, 96])
        view_5 = torch.ops.aten.view.default(getitem_75, [8, 96])
        view_6 = torch.ops.aten.view.default(getitem_76, [8, 96])
        view_7 = torch.ops.aten.view.default(getitem_77, [8, 96])
        # clone is a second consumer of view_1 (the views have MULTIPLE users
        # in the pattern).
        clone = torch.ops.aten.clone.default(view_1)

        # Concatenate all seven views on dim 1 -- the rewrite target.
        cat = torch.ops.aten.cat.default(
            [
                view_1,
                view_2,
                view_3,
                view_4,
                view_5,
                view_6,
                view_7,
            ],
            1,
        )
        cat_1 = torch.ops.aten.cat.default([clone, cat], 1)
        return torch.cat([clone, cat_1], 1)
|
||||
|
||||
|
||||
class TestSelectCat(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
|
@ -260,6 +300,30 @@ class TestSplitCatAten(TestCase):
|
|||
self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
|
||||
counters.clear()
|
||||
|
||||
@requires_cuda
|
||||
@torch._inductor.config.patch(
|
||||
pre_grad_fusion_options={},
|
||||
post_grad_fusion_options={
|
||||
"normalization_aten_pass": {},
|
||||
"move_view_after_cat_aten_pass": {},
|
||||
},
|
||||
)
|
||||
def test_move_view_after_cat_aten(self):
|
||||
counters.clear()
|
||||
inputs = [
|
||||
torch.randn(7, 8, 96, device=torch.device(device=GPU_TYPE)),
|
||||
]
|
||||
module = TestMoveViewAferCat()
|
||||
traced = torch.compile(module)
|
||||
ref = module(*inputs)
|
||||
res = traced(*inputs)
|
||||
self.compare_pred(module, traced, inputs)
|
||||
self.assertEqual(counters["inductor"]["normalization_aten_pass"], 4)
|
||||
self.assertEqual(counters["inductor"]["move_view_after_cat_aten_pass"], 1)
|
||||
self.assertEqual(ref, res, rtol=1e-8, atol=1e-8)
|
||||
self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
|
||||
counters.clear()
|
||||
|
||||
|
||||
# Allow running this test file directly; run_tests() is the standard
# torch test-suite entry point.
if __name__ == "__main__":
    run_tests()
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ post_grad_pass_names = [
|
|||
"pad_aten_mm_pass",
|
||||
"split_cat_aten_pass",
|
||||
"select_cat_aten_pass",
|
||||
"move_view_after_cat_aten_pass",
|
||||
]
|
||||
|
||||
for pass_name in pre_grad_pass_names:
|
||||
|
|
@ -2864,3 +2865,95 @@ def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
|
|||
remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type]
|
||||
remove_split_unbind_children(graph, split_users) # type: ignore[arg-type]
|
||||
counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1
|
||||
|
||||
|
||||
# Pattern: a (partial) list of reshape(getitem(split_with_sizes(...), idx))
# nodes, i.e. views fed by getitems of one split node -- the cat inputs the
# move_view_after_cat pass looks for.
_split_with_sizes_pattern = CallFunctionVarArgs(
    torch.ops.aten.split_with_sizes.default, users=MULTIPLE
)
_getitem_of_split_pattern = CallFunction(
    operator.getitem,
    _split_with_sizes_pattern,
    Ignored(),
    _users=MULTIPLE,
)
view_getitem_split_aten = ListOf(
    CallFunction(
        [torch.ops.aten.reshape.default],
        _getitem_of_split_pattern,
        Arg(),
        _users=MULTIPLE,
    ),
    partial=True,
)
|
||||
|
||||
|
||||
@register_graph_pattern(
    CallFunction(
        torch.ops.aten.cat.default,
        view_getitem_split_aten,
        dim=Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=construct_pattern_matcher_pass("move_view_after_cat_aten_pass"),
)
def move_view_after_cat(match: Match, *args, **kwargs):
    """Rewrite ``cat([view(getitem_i(split(x))) ...])`` into a single
    reshape of the split input (with a permute first when the cat dim
    differs from the split dim), removing the per-chunk views and the cat.

    Only fires when the cat consumes exactly one view per split section,
    every view reads a getitem of the split, and the getitem/view indices
    form a sorted, consecutive run (so the cat reproduces the split input's
    element order).
    """
    split_node = next(
        node
        for node in match.nodes
        if node.target == torch.ops.aten.split_with_sizes.default
    )
    split_input, split_section, split_dim = _get_split_args_default(split_node)
    split_users = list(split_node.users.keys())
    getitem_indices = [
        getitem.args[1] for getitem in split_users if getitem.target == operator.getitem
    ]
    if not is_sorted_and_consecutive(getitem_indices):  # type: ignore[arg-type]
        return
    cat_nodes = [
        node for node in match.nodes if node.target == torch.ops.aten.cat.default
    ]
    graph = match.graph
    for cat_node in cat_nodes:
        if not is_node_meta_valid(cat_node):
            log.debug("example value absent for node: %s", cat_node)
            continue
        cat_dim = _get_dim(cat_node)
        cat_inputs = get_arg_value(cat_node, 0, "tensors")  # type: ignore[union-attr]
        # we only consider the following special case: one cat input per
        # split section
        if len(cat_inputs) != len(split_section):
            continue
        # check if the cat inputs are all the view nodes
        if not all(
            view_node.target == torch.ops.aten.reshape.default
            for view_node in cat_inputs
        ):
            continue
        # check if the view nodes are all from getitem nodes
        if not all(
            view_node.args[0].target == operator.getitem for view_node in cat_inputs
        ):
            continue
        view_indices = [view.args[0].args[1] for view in cat_inputs]
        if not is_sorted_and_consecutive(view_indices):  # type: ignore[arg-type]
            continue
        if cat_dim != split_dim:
            # The split input has one more dim than the cat output, hence
            # rank + 1 here (presumably each view flattens away the split
            # dim, as in the fixture graph -- NOTE(review): confirm this
            # holds for all matched graphs).
            permute_list = list(range(len(cat_node.meta["val"].shape) + 1))
            permute_list[split_dim], permute_list[cat_dim] = (
                permute_list[cat_dim],
                permute_list[split_dim],
            )
        else:
            permute_list = None

        with graph.inserting_before(cat_node):
            if permute_list is not None:
                # Create the permute inside the insertion context (not at the
                # graph's default insertion point) so it is guaranteed to be
                # topologically ordered before the reshape that consumes it.
                source_node = graph.call_function(
                    torch.ops.aten.permute.default,
                    args=(split_input, permute_list),
                )
            else:
                source_node = split_input
            view_node = graph.call_function(
                torch.ops.aten.reshape.default,
                args=(source_node, list(cat_node.meta["val"].shape)),
            )
        cat_node.replace_all_uses_with(view_node)
        view_node.meta.update(cat_node.meta)
        graph.erase_node(cat_node)
        counters["inductor"]["move_view_after_cat_aten_pass"] += 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user