[inductor][debug] fix draw_buffers (#135266)

**Before:**
![image](https://github.com/user-attachments/assets/aac756f3-1349-4647-9da3-87cf105cf647)

**After:**
<img width="791" alt="image" src="https://github.com/user-attachments/assets/d72c663c-e598-42fa-ac40-9e58956f1ec1">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135266
Approved by: https://github.com/yf225
This commit is contained in:
Xuan Zhang 2024-09-06 04:12:39 +00:00 committed by PyTorch MergeBot
parent 5f57be7571
commit c05a7adb36

View File

@ -111,6 +111,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])
buf_to_fx_node = {}
node_to_fx_node = {}
graph = torch.fx.Graph()
first_node = None
@ -162,10 +163,9 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)
if isinstance(snode, FusedSchedulerNode):
for x in snode.snodes:
buf_to_fx_node[x.get_name()] = fx_node
buf_to_fx_node[name] = fx_node
node_to_fx_node[name] = fx_node
for buf in snode.get_outputs():
buf_to_fx_node[buf.get_name()] = fx_node
if first_node is None:
first_node = fx_node
@ -175,7 +175,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
name = snode.get_name()
deps = snode.read_writes.reads
fx_node = buf_to_fx_node[name]
fx_node = node_to_fx_node[name]
new_args = []
for dep in deps:
if dep.name in buf_to_fx_node:
@ -184,6 +184,8 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
with graph.inserting_before(first_node):
dep_node = graph.placeholder(dep.name)
buf_to_fx_node[dep.name] = dep_node
if dep_node == fx_node: # to avoid cycles
continue
new_args.append(dep_node)
fx_node.args = tuple(new_args)