[inductor] skip bmm when converting channel last (#159459)

Workaround for #159458: remove some nodes from the output channels-last set

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159459
Approved by: https://github.com/etaf, https://github.com/eellison, https://github.com/shunting314
This commit is contained in:
jianyizh 2025-09-26 09:11:36 +00:00 committed by PyTorch MergeBot
parent 4783e3ff49
commit 6a2bd1f4ee

View File

@ -851,12 +851,17 @@ class GraphLowering(torch.fx.Interpreter):
With rule 2, we make sure all the tensors in the chain use the channels last layout. So both copies
can be saved.
"""
last_conv = None
nodes_cannot_propagate = [torch.ops.aten.bmm.default]
output_set = OrderedSet[Node]()
for n in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr]
if n.target == torch.ops.aten.convolution.default:
output_set.add(n)
if last_conv is None:
last_conv = n
continue
if n.target in nodes_cannot_propagate:
continue
for user in n.users:
if user in output_set:
output_set.add(n)
@ -877,8 +882,14 @@ class GraphLowering(torch.fx.Interpreter):
# - res2net50_14w_8s
# - sebotnet33ts_256
for n in self.module.graph.nodes: # type: ignore[union-attr]
# layout propagation ends at the last conv node, which will benefit vision transformers.
if last_conv is not None and n == last_conv:
break
if n in output_set:
output_set.update(n.users)
for user in n.users:
if user.target in nodes_cannot_propagate:
continue
output_set.add(user)
return output_set