[inductor] Fix collectives_reordering overwrite real_dep with fake_dep with the same name (#158960)
Differential Revision: [D78839734](https://our.internmc.facebook.com/intern/diff/D78839734)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158960
Approved by: https://github.com/wconstab
This commit is contained in:
parent 3e954d3943
commit 3ced1079a4
@@ -184,6 +184,10 @@ def _group_name(snode, with_bufs=False) -> str:
     return ret


+def _is_fake_dep(d):
+    return isinstance(d, WeakDep) and d.is_fake
+
+
 def _reorder_communication_preserving_peak_memory_internal(
     snodes: list[BaseSchedulerNode],
 ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
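For intuition, here is a minimal, self-contained sketch of what the new _is_fake_dep predicate distinguishes. MemoryDep and WeakDep below are stand-ins for the real torch._inductor dependency types, not the actual classes; only the helper's logic mirrors the diff.

from dataclasses import dataclass


@dataclass(frozen=True)
class MemoryDep:
    # Stand-in for a real data dependency on a buffer.
    name: str


@dataclass(frozen=True)
class WeakDep:
    # Stand-in for an ordering-only dependency; when is_fake is set,
    # the reordering passes are meant to ignore it entirely.
    name: str
    is_fake: bool = False


def _is_fake_dep(d) -> bool:
    # Mirrors the helper added in the PR: only fake WeakDeps are ignored.
    return isinstance(d, WeakDep) and d.is_fake


assert _is_fake_dep(WeakDep("buf0", is_fake=True))
assert not _is_fake_dep(WeakDep("buf0"))    # non-fake weak dep still counts
assert not _is_fake_dep(MemoryDep("buf0"))  # real data dep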
@@ -294,13 +298,17 @@ def _reorder_communication_preserving_peak_memory_internal(
                 temp_grouping=True,
             )

-            data_deps = {s.name: s for s in group.unmet_dependencies}
+            # We can have multiple deps with the same name.
+            # As we ignore WeakDep(is_fake=True) =>
+            # filter them out first to avoid overwriting of real dep.
+            data_deps = {
+                d.name: d for d in group.unmet_dependencies if not _is_fake_dep(d)
+            }

             candidate_outs = candidate.get_outputs()
             data_dep = None
             for o in candidate_outs:
                 if d := data_deps.get(o.get_name(), None):
-                    if isinstance(d, WeakDep) and d.is_fake:
-                        continue
                     data_dep = d
                     break
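This hunk is the core of the fix: unmet_dependencies can contain several deps with the same buffer name, e.g. a real MemoryDep plus a fake WeakDep. When the map is keyed by name, whichever dep comes last wins, so a fake dep could shadow the real one and the old in-loop is_fake check would then skip it, making the candidate look freely reorderable when it is not. A hedged sketch of that failure mode, reusing the stand-in classes from the sketch above (buffer names are made up):

# Two deps on the same buffer: a real read plus a fake, ordering-only WeakDep.
unmet_dependencies = [
    MemoryDep("buf3"),
    WeakDep("buf3", is_fake=True),
]

# Old behaviour: the last entry with a given name overwrites earlier ones,
# so the real MemoryDep on "buf3" is hidden behind the fake WeakDep and the
# subsequent is_fake check skips the name entirely.
old = {d.name: d for d in unmet_dependencies}
assert isinstance(old["buf3"], WeakDep)  # real dep shadowed

# Fixed behaviour: drop fake deps before building the map, so the real
# dependency on "buf3" is the one the reordering pass consults.
new = {d.name: d for d in unmet_dependencies if not _is_fake_dep(d)}
assert isinstance(new["buf3"], MemoryDep)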
@@ -703,14 +711,20 @@ def _sink_waits_iterative_internal(
                 _group_nodes(group_head, group_tail),
                 temp_grouping=True,
             )
-            group_outs = group.get_outputs()

-            data_deps = {s.name: s for s in candidate.unmet_dependencies}
+            # We can have multiple deps with the same name.
+            # As we ignore WeakDep(is_fake=True) =>
+            # filter them out first to avoid overwriting of real dep.
+            data_deps = {
+                d.name: d
+                for d in candidate.unmet_dependencies
+                if not _is_fake_dep(d)
+            }
+
+            group_outs = group.get_outputs()
             data_dep = None
             for o in group_outs:
                 if d := data_deps.get(o.get_name(), None):
-                    if isinstance(d, WeakDep) and d.is_fake:
-                        continue
                     data_dep = d
                     break
             # 1. If we have data_dep - we can not swap => trying to group
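The same pattern is applied in the wait-sinking pass: once fake deps are filtered out up front, the lookup over the group's outputs no longer needs an in-loop is_fake check, and a candidate whose only link to the group is a fake WeakDep correctly ends up with data_dep set to None, i.e. it is safe to move past the group. A hedged sketch of that lookup shape, again with the stand-in classes and hypothetical buffer names (first_real_dep is an illustrative helper, not inductor API):

def first_real_dep(output_names, unmet_dependencies):
    # Same shape as the fixed lookup in _sink_waits_iterative_internal.
    data_deps = {
        d.name: d for d in unmet_dependencies if not _is_fake_dep(d)
    }
    for name in output_names:
        if d := data_deps.get(name, None):
            return d
    return None


group_outputs = ["buf7", "buf8"]  # hypothetical outputs of a grouped region
only_fake = [WeakDep("buf7", is_fake=True)]
real_read = [WeakDep("buf7", is_fake=True), MemoryDep("buf8")]

assert first_real_dep(group_outputs, only_fake) is None       # safe to move
assert isinstance(first_real_dep(group_outputs, real_read), MemoryDep)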
@@ -2151,6 +2151,23 @@ class Scheduler:
             OrderedSet(V.graph.get_output_names()),
         )
         if config.reorder_for_compute_comm_overlap:
+            from torch._logging import trace_structured
+
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "scheduler_nodes_before_comm_overlap",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: "\n\n".join(
+                    [
+                        f"snode[{i}]"
+                        + n.debug_str()
+                        + f" buffer_names:{n.get_buffer_names()}"
+                        for i, n in enumerate(self.nodes)
+                    ]
+                ),
+            )
             self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
             self.process_grouped_nodes()
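This Scheduler hunk adds a structured-trace artifact that dumps every scheduler node, plus its buffer names, immediately before the comm-overlap reordering runs, so node order can be compared before and after the passes. A hedged sketch of the same logging pattern outside the scheduler; trace_structured and its metadata_fn/payload_fn keywords come straight from the diff, while the node list here is a made-up stand-in and the tooling note is an assumption about the usual structured-trace workflow:

from torch._logging import trace_structured

# Stand-in for self.nodes: any sequence of objects with a printable summary.
nodes = ["buf0 = mm(...)", "buf1 = all_reduce(buf0)"]

trace_structured(
    "artifact",
    metadata_fn=lambda: {
        "name": "scheduler_nodes_before_comm_overlap",  # artifact name from the diff
        "encoding": "string",
    },
    payload_fn=lambda: "\n\n".join(
        f"snode[{i}] {n}" for i, n in enumerate(nodes)
    ),
)
# The call is effectively a no-op unless structured tracing is enabled; when it
# is (e.g. TORCH_TRACE is set and the log is viewed with tlparse), the payload
# appears as a named artifact in the trace.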