[inductor] skip bmm when converting channel last (#159459)
Workaround for #159458: remove some nodes (bmm) from the channels-last output set so layout propagation skips them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159459
Approved by: https://github.com/etaf, https://github.com/eellison, https://github.com/shunting314
This commit is contained in:
parent 4783e3ff49
commit 6a2bd1f4ee
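For context (not part of the commit itself): one reason a bmm node cannot simply adopt the channels-last preference is that torch.channels_last is only defined for rank-4 (NCHW) tensors, while aten.bmm consumes and produces rank-3 tensors; see #159458 for the actual failure this works around. A minimal illustration, assuming only stock PyTorch:

import torch

# A 4D conv-style activation can adopt the channels_last memory format.
x4d = torch.randn(2, 3, 8, 8)
x4d = x4d.contiguous(memory_format=torch.channels_last)
print(x4d.is_contiguous(memory_format=torch.channels_last))  # True

# bmm works on rank-3 tensors, which have no channels_last layout at all.
a = torch.randn(4, 8, 16)
b = torch.randn(4, 16, 8)
out = torch.bmm(a, b)  # shape (4, 8, 8)
try:
    out.contiguous(memory_format=torch.channels_last)
except RuntimeError as err:
    print("rank-3 bmm output rejects channels_last:", err)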
@@ -851,12 +851,17 @@ class GraphLowering(torch.fx.Interpreter):
         With rule 2, we make sure all the tensors in the chain use channels last layout. So both copies
         can be saved.
         """
         last_conv = None
+        nodes_cannot_propagate = [torch.ops.aten.bmm.default]
         output_set = OrderedSet[Node]()
         for n in reversed(self.module.graph.nodes):  # type: ignore[arg-type, union-attr]
             if n.target == torch.ops.aten.convolution.default:
                 output_set.add(n)
                 if last_conv is None:
                     last_conv = n
                 continue
+            if n.target in nodes_cannot_propagate:
+                continue
 
             for user in n.users:
                 if user in output_set:
                     output_set.add(n)
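To make the control flow of this first, reverse pass easier to follow in isolation, here is a small standalone sketch. ToyNode, backward_collect, and the string targets are hypothetical stand-ins for torch.fx.Node and the aten overloads, and a plain set stands in for inductor's OrderedSet; it only mirrors the logic of the hunk above (last_conv tracking omitted) and is not inductor code.

from dataclasses import dataclass, field

@dataclass(eq=False)  # eq=False keeps identity hashing so nodes can sit in sets
class ToyNode:
    name: str
    target: str
    users: list = field(default_factory=list)

def backward_collect(nodes):
    """Reverse walk: convs seed the set, bmm blocks propagation, any other
    node joins when one of its users is already in the set."""
    nodes_cannot_propagate = ["bmm"]
    output_set = set()
    for n in reversed(nodes):
        if n.target == "convolution":
            output_set.add(n)
            continue
        if n.target in nodes_cannot_propagate:
            continue
        for user in n.users:
            if user in output_set:
                output_set.add(n)
    return output_set

# conv1 -> relu -> bmm -> conv2: relu only feeds the bmm, so the channels-last
# preference from conv2 does not leak backwards through it.
conv2 = ToyNode("conv2", "convolution")
bmm = ToyNode("bmm", "bmm", users=[conv2])
relu = ToyNode("relu", "relu", users=[bmm])
conv1 = ToyNode("conv1", "convolution", users=[relu])
print(sorted(n.name for n in backward_collect([conv1, relu, bmm, conv2])))
# -> ['conv1', 'conv2']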
@@ -877,8 +882,14 @@ class GraphLowering(torch.fx.Interpreter):
         # - res2net50_14w_8s
         # - sebotnet33ts_256
         for n in self.module.graph.nodes:  # type: ignore[union-attr]
             # layout propagation ends at last conv node, which will benefit vision transformers.
             if last_conv is not None and n == last_conv:
                 break
             if n in output_set:
-                output_set.update(n.users)
+                for user in n.users:
+                    if user.target in nodes_cannot_propagate:
+                        continue
+                    output_set.add(user)
 
         return output_set
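And a matching standalone sketch of the second, forward pass: nodes already in the set push the preference onto their users, except users whose target is in nodes_cannot_propagate. ToyNode is the same hypothetical stand-in as in the previous sketch, redeclared so this snippet runs on its own.

from dataclasses import dataclass, field

@dataclass(eq=False)  # identity hashing so nodes can sit in sets
class ToyNode:
    name: str
    target: str
    users: list = field(default_factory=list)

def forward_propagate(nodes, output_set, last_conv=None):
    nodes_cannot_propagate = ["bmm"]
    for n in nodes:
        # layout propagation ends at the last conv node
        if last_conv is not None and n is last_conv:
            break
        if n in output_set:
            for user in n.users:
                if user.target in nodes_cannot_propagate:
                    continue
                output_set.add(user)
    return output_set

# conv -> {relu, bmm}: relu inherits the channels-last preference, bmm does not,
# which is exactly the filtering the replaced output_set.update(n.users) lacked.
relu = ToyNode("relu", "relu")
bmm = ToyNode("bmm", "bmm")
conv = ToyNode("conv", "convolution", users=[relu, bmm])
print(sorted(n.name for n in forward_propagate([conv, relu, bmm], {conv})))
# -> ['conv', 'relu']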