[inductor] Fix collectives_reordering overwriting real_dep with a fake_dep of the same name (#158960)

Differential Revision: [D78839734](https://our.internmc.facebook.com/intern/diff/D78839734)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158960
Approved by: https://github.com/wconstab
IvanKobzarev 2025-07-23 13:47:29 -07:00 committed by PyTorch MergeBot
parent 3e954d3943
commit 3ced1079a4
2 changed files with 38 additions and 7 deletions
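
For context, here is a minimal standalone sketch of the failure mode this commit fixes. The `Dep` dataclass below is a hypothetical stand-in for the scheduler's real dependency objects (`MemoryDep`/`WeakDep` in `torch._inductor.dependencies`), used only for illustration: when the dependency dict is keyed by name, a fake `WeakDep` that appears later in `unmet_dependencies` silently overwrites the real dep with the same name, and the subsequent "skip fake deps" check then drops the data dependency altogether.

from dataclasses import dataclass


@dataclass(frozen=True)
class Dep:
    # Hypothetical stand-in for the scheduler's dependency classes.
    name: str
    is_fake: bool = False


def _is_fake_dep(d: Dep) -> bool:
    return d.is_fake


unmet_dependencies = [
    Dep("buf0"),                # real data dependency on buf0
    Dep("buf0", is_fake=True),  # fake WeakDep with the same name
]

# Old behavior: the last duplicate key wins, so the fake dep shadows the real
# one; the later is_fake check then skips it and the real dependency is lost.
old = {d.name: d for d in unmet_dependencies}
assert old["buf0"].is_fake  # real dep overwritten

# Fixed behavior: filter fake deps out before building the dict,
# so the real dep survives the name collision.
new = {d.name: d for d in unmet_dependencies if not _is_fake_dep(d)}
assert not new["buf0"].is_fake  # real dep preserved

With the old dict, the reordering pass could conclude there is no real data dependency on buf0 and swap nodes across it; the filtered dict keeps the real dep and blocks the swap.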


@@ -184,6 +184,10 @@ def _group_name(snode, with_bufs=False) -> str:
     return ret
 
 
+def _is_fake_dep(d):
+    return isinstance(d, WeakDep) and d.is_fake
+
+
 def _reorder_communication_preserving_peak_memory_internal(
     snodes: list[BaseSchedulerNode],
 ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
@@ -294,13 +298,17 @@ def _reorder_communication_preserving_peak_memory_internal(
                 temp_grouping=True,
             )
-            data_deps = {s.name: s for s in group.unmet_dependencies}
+            # We can have multiple deps with the same name.
+            # As we ignore WeakDep(is_fake=True) =>
+            # filter them out first to avoid overwriting of real dep.
+            data_deps = {
+                d.name: d for d in group.unmet_dependencies if not _is_fake_dep(d)
+            }
             candidate_outs = candidate.get_outputs()
             data_dep = None
             for o in candidate_outs:
                 if d := data_deps.get(o.get_name(), None):
-                    if isinstance(d, WeakDep) and d.is_fake:
-                        continue
                     data_dep = d
                     break
@@ -703,14 +711,20 @@ def _sink_waits_iterative_internal(
                 _group_nodes(group_head, group_tail),
                 temp_grouping=True,
             )
-            group_outs = group.get_outputs()
-            data_deps = {s.name: s for s in candidate.unmet_dependencies}
+            # We can have multiple deps with the same name.
+            # As we ignore WeakDep(is_fake=True) =>
+            # filter them out first to avoid overwriting of real dep.
+            data_deps = {
+                d.name: d
+                for d in candidate.unmet_dependencies
+                if not _is_fake_dep(d)
+            }
+            group_outs = group.get_outputs()
             data_dep = None
             for o in group_outs:
                 if d := data_deps.get(o.get_name(), None):
-                    if isinstance(d, WeakDep) and d.is_fake:
-                        continue
                     data_dep = d
                     break
             # 1. If we have data_dep - we can not swap => trying to group


@@ -2151,6 +2151,23 @@ class Scheduler:
             OrderedSet(V.graph.get_output_names()),
         )
         if config.reorder_for_compute_comm_overlap:
+            from torch._logging import trace_structured
+
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "scheduler_nodes_before_comm_overlap",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: "\n\n".join(
+                    [
+                        f"snode[{i}]"
+                        + n.debug_str()
+                        + f" buffer_names:{n.get_buffer_names()}"
+                        for i, n in enumerate(self.nodes)
+                    ]
+                ),
+            )
             self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
         self.process_grouped_nodes()
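
The scheduler-side change above only adds logging: it records a structured trace artifact named scheduler_nodes_before_comm_overlap with one text block per scheduler node right before the comm-overlap reordering runs. Below is a rough sketch of the payload shape it produces, using a hypothetical ToyNode in place of the real BaseSchedulerNode (whose debug_str() output is far more detailed).

class ToyNode:
    # Hypothetical stand-in for BaseSchedulerNode, for illustration only.
    def __init__(self, name: str, buffer_names: set[str]) -> None:
        self.name = name
        self._buffer_names = buffer_names

    def debug_str(self) -> str:
        return f"{self.name}: ExternKernelSchedulerNode(...)"

    def get_buffer_names(self) -> set[str]:
        return self._buffer_names


nodes = [ToyNode("op0", {"buf0"}), ToyNode("op1", {"buf1"})]

# Mirrors the payload_fn in the diff: one block per scheduler node,
# blocks separated by blank lines.
payload = "\n\n".join(
    f"snode[{i}]" + n.debug_str() + f" buffer_names:{n.get_buffer_names()}"
    for i, n in enumerate(nodes)
)
print(payload)
# snode[0]op0: ExternKernelSchedulerNode(...) buffer_names:{'buf0'}
# snode[1]op1: ExternKernelSchedulerNode(...) buffer_names:{'buf1'}

The artifact is intended for PyTorch's structured trace tooling (e.g. tlparse over a TORCH_TRACE log), which makes it easier to compare node order before and after reorder_compute_and_comm_for_overlap.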