mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[Inductor][Optimus] Add move view after cat aten pattern (#149178)
Summary: Add an aten pattern to move the view/reshape out of split cat, further reducing the number of kernels. Context: https://docs.google.com/document/d/1G2qFcQu1K7VXbz2uPe0CS2aBirnwtwI_B8lxmlBlAPQ/edit?tab=t.0 Test Plan: ### how to enable Add the following patterns to the post grad ``` post_grad_fusion_options={ "normalization_aten_pass": {}, "move_view_after_cat_aten_pass": {}, }, ``` ### unit test ``` buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:split_cat_fx_aten_passes -- test_move_view_after_cat_aten ``` Buck UI: https://www.internalfb.com/buck2/3c5451be-c63a-4794-8d6b-103ecac78905 Test UI: https://www.internalfb.com/intern/testinfra/testrun/6192449704507267 ### local reproduce ``` buck2 run mode/opt scripts/shuaiyang:test -- --flow_id 691990503 --use_synthetic_data --optimus ``` https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/mengluy/2025-03-13-20-59-34/trace.json.gz&bucket=gpu_traces ### E2E baseline f691990503 proposal Differential Revision: D71177004 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149178 Approved by: https://github.com/Yuzhen11
This commit is contained in:
parent
95e71765f2
commit
fe94d7da1a
|
|
@ -135,6 +135,46 @@ class TestSplitCatPartial(torch.nn.Module):
|
||||||
return cat
|
return cat
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoveViewAferCat(torch.nn.Module):
    """Repro module for the ``move_view_after_cat_aten_pass``.

    The forward splits a (7, 8, 96) input into 7 unit slices along dim 0,
    views each slice as (8, 96), and concatenates the views along dim 1 —
    exactly the split -> getitem -> view -> cat shape the post-grad pass
    is expected to rewrite into a single permute + reshape.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor):
        # Split into 7 slices of size 1 along the leading dim.
        split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(
            x, [1, 1, 1, 1, 1, 1, 1]
        )
        # View every (1, 8, 96) slice as (8, 96). The explicit indexing
        # keeps the traced graph's getitem nodes identical to the
        # hand-unrolled original.
        views = []
        for idx in range(7):
            piece = split_with_sizes_1[idx]
            views.append(torch.ops.aten.view.default(piece, [8, 96]))

        # Extra user of the first view so the pattern sees MULTIPLE users.
        clone = torch.ops.aten.clone.default(views[0])

        cat = torch.ops.aten.cat.default(list(views), 1)
        cat_1 = torch.ops.aten.cat.default([clone, cat], 1)
        return torch.cat([clone, cat_1], 1)
||||||
|
|
||||||
class TestSelectCat(torch.nn.Module):
|
class TestSelectCat(torch.nn.Module):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -260,6 +300,30 @@ class TestSplitCatAten(TestCase):
|
||||||
self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
|
self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
|
||||||
counters.clear()
|
counters.clear()
|
||||||
|
|
||||||
|
@requires_cuda
|
||||||
|
@torch._inductor.config.patch(
|
||||||
|
pre_grad_fusion_options={},
|
||||||
|
post_grad_fusion_options={
|
||||||
|
"normalization_aten_pass": {},
|
||||||
|
"move_view_after_cat_aten_pass": {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def test_move_view_after_cat_aten(self):
|
||||||
|
counters.clear()
|
||||||
|
inputs = [
|
||||||
|
torch.randn(7, 8, 96, device=torch.device(device=GPU_TYPE)),
|
||||||
|
]
|
||||||
|
module = TestMoveViewAferCat()
|
||||||
|
traced = torch.compile(module)
|
||||||
|
ref = module(*inputs)
|
||||||
|
res = traced(*inputs)
|
||||||
|
self.compare_pred(module, traced, inputs)
|
||||||
|
self.assertEqual(counters["inductor"]["normalization_aten_pass"], 4)
|
||||||
|
self.assertEqual(counters["inductor"]["move_view_after_cat_aten_pass"], 1)
|
||||||
|
self.assertEqual(ref, res, rtol=1e-8, atol=1e-8)
|
||||||
|
self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
|
||||||
|
counters.clear()
|
||||||
|
|
||||||
|
|
||||||
# Standard test-file entry point: run the test suite when executed directly.
if __name__ == "__main__":
    run_tests()
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,7 @@ post_grad_pass_names = [
|
||||||
"pad_aten_mm_pass",
|
"pad_aten_mm_pass",
|
||||||
"split_cat_aten_pass",
|
"split_cat_aten_pass",
|
||||||
"select_cat_aten_pass",
|
"select_cat_aten_pass",
|
||||||
|
"move_view_after_cat_aten_pass",
|
||||||
]
|
]
|
||||||
|
|
||||||
for pass_name in pre_grad_pass_names:
|
for pass_name in pre_grad_pass_names:
|
||||||
|
|
@ -2864,3 +2865,95 @@ def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
|
||||||
remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type]
|
remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type]
|
||||||
remove_split_unbind_children(graph, split_users) # type: ignore[arg-type]
|
remove_split_unbind_children(graph, split_users) # type: ignore[arg-type]
|
||||||
counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1
|
counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1
|
||||||
|
|
||||||
|
|
||||||
|
# Pattern: a list of aten.reshape (view) nodes, each consuming an
# operator.getitem result of a single aten.split_with_sizes node —
# i.e. the split -> getitem -> view chain feeding a cat.
# partial=True lets the enclosing cat match even when not every cat
# input fits this shape; the pass re-validates the inputs itself.
view_getitem_split_aten = ListOf(
    CallFunction(
        [torch.ops.aten.reshape.default],
        CallFunction(
            operator.getitem,
            CallFunctionVarArgs(
                torch.ops.aten.split_with_sizes.default, users=MULTIPLE
            ),
            Ignored(),  # getitem index — checked for consecutiveness later
            _users=MULTIPLE,
        ),
        Arg(),  # target shape of the view
        _users=MULTIPLE,
    ),
    partial=True,
)
||||||
|
|
||||||
|
@register_graph_pattern(
    CallFunction(
        torch.ops.aten.cat.default,
        view_getitem_split_aten,
        dim=Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=construct_pattern_matcher_pass("move_view_after_cat_aten_pass"),
)
def move_view_after_cat(match: Match, *args, **kwargs):
    """Move views out of a split -> getitem -> view -> cat chain.

    When a cat consumes one view per split output, in order, the whole
    chain is equivalent to a single reshape of the split input (plus a
    permute when the cat dim differs from the split dim), so the cat is
    replaced by that reshape and the counter is bumped once per rewrite.
    """
    split_node = next(
        node
        for node in match.nodes
        if node.target == torch.ops.aten.split_with_sizes.default
    )
    split_input, split_section, split_dim = _get_split_args_default(split_node)
    split_users = list(split_node.users.keys())
    getitem_indices = [
        getitem.args[1] for getitem in split_users if getitem.target == operator.getitem
    ]
    # The split outputs must be consumed in order with no gaps, otherwise
    # the cat is not a simple re-assembly of the split input.
    if not is_sorted_and_consecutive(getitem_indices):  # type: ignore[arg-type]
        return
    cat_nodes = [
        node for node in match.nodes if node.target == torch.ops.aten.cat.default
    ]
    graph = match.graph
    for cat_node in cat_nodes:
        if not is_node_meta_valid(cat_node):
            log.debug("example value absent for node: %s", cat_node)
            continue
        cat_dim = _get_dim(cat_node)
        cat_inputs = get_arg_value(cat_node, 0, "tensors")  # type: ignore[union-attr]
        # we only consider the special case where the cat consumes exactly
        # one input per split section
        if len(cat_inputs) != len(split_section):
            continue
        # check if the cat inputs are all the view nodes
        if not all(
            view_node.target == torch.ops.aten.reshape.default
            for view_node in cat_inputs
        ):
            continue
        # check if the view nodes are all from getitem nodes
        if not all(
            view_node.args[0].target == operator.getitem for view_node in cat_inputs
        ):
            continue
        view_indices = [view.args[0].args[1] for view in cat_inputs]
        if not is_sorted_and_consecutive(view_indices):  # type: ignore[arg-type]
            continue
        # BUGFIX: all replacement nodes must be created inside
        # inserting_before(cat_node). Previously the permute node was built
        # at the graph's default insertion point (the end), while its user
        # (the new reshape) was inserted before cat_node — leaving the FX
        # graph non-topologically ordered (use before def).
        with graph.inserting_before(cat_node):
            if cat_dim != split_dim:
                # Construct a permute that swaps split_dim and cat_dim on
                # the split input (which has one more dim than the cat
                # output, hence the + 1).
                permute_list = list(range(len(cat_node.meta["val"].shape) + 1))
                permute_list[split_dim], permute_list[cat_dim] = (
                    permute_list[cat_dim],
                    permute_list[split_dim],
                )
                permute_node = graph.call_function(
                    torch.ops.aten.permute.default,
                    args=(split_input, permute_list),
                )
            else:
                # Same dim: the reshape can consume the split input directly.
                permute_node = split_input

            view_node = graph.call_function(
                torch.ops.aten.reshape.default,
                args=(permute_node, list(cat_node.meta["val"].shape)),
            )
        cat_node.replace_all_uses_with(view_node)
        view_node.meta.update(cat_node.meta)
        graph.erase_node(cat_node)
        counters["inductor"]["move_view_after_cat_aten_pass"] += 1
|
||||||
Loading…
Reference in New Issue
Block a user