mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] Fix collectives_reordering overwrite real_dep with fake_dep with the same name (#158960)
Differential Revision: [D78839734](https://our.internmc.facebook.com/intern/diff/D78839734) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158960 Approved by: https://github.com/wconstab
This commit is contained in:
parent
3e954d3943
commit
3ced1079a4
|
|
@ -184,6 +184,10 @@ def _group_name(snode, with_bufs=False) -> str:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def _is_fake_dep(d):
|
||||||
|
return isinstance(d, WeakDep) and d.is_fake
|
||||||
|
|
||||||
|
|
||||||
def _reorder_communication_preserving_peak_memory_internal(
|
def _reorder_communication_preserving_peak_memory_internal(
|
||||||
snodes: list[BaseSchedulerNode],
|
snodes: list[BaseSchedulerNode],
|
||||||
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
|
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
|
||||||
|
|
@ -294,13 +298,17 @@ def _reorder_communication_preserving_peak_memory_internal(
|
||||||
temp_grouping=True,
|
temp_grouping=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
data_deps = {s.name: s for s in group.unmet_dependencies}
|
# We can have multiple deps with the same name.
|
||||||
|
# As we ignore WeakDep(is_fake=True) =>
|
||||||
|
# filter them out first to avoid overwriting of real dep.
|
||||||
|
data_deps = {
|
||||||
|
d.name: d for d in group.unmet_dependencies if not _is_fake_dep(d)
|
||||||
|
}
|
||||||
|
|
||||||
candidate_outs = candidate.get_outputs()
|
candidate_outs = candidate.get_outputs()
|
||||||
data_dep = None
|
data_dep = None
|
||||||
for o in candidate_outs:
|
for o in candidate_outs:
|
||||||
if d := data_deps.get(o.get_name(), None):
|
if d := data_deps.get(o.get_name(), None):
|
||||||
if isinstance(d, WeakDep) and d.is_fake:
|
|
||||||
continue
|
|
||||||
data_dep = d
|
data_dep = d
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@ -703,14 +711,20 @@ def _sink_waits_iterative_internal(
|
||||||
_group_nodes(group_head, group_tail),
|
_group_nodes(group_head, group_tail),
|
||||||
temp_grouping=True,
|
temp_grouping=True,
|
||||||
)
|
)
|
||||||
group_outs = group.get_outputs()
|
|
||||||
|
|
||||||
data_deps = {s.name: s for s in candidate.unmet_dependencies}
|
# We can have multiple deps with the same name.
|
||||||
|
# As we ignore WeakDep(is_fake=True) =>
|
||||||
|
# filter them out first to avoid overwriting of real dep.
|
||||||
|
data_deps = {
|
||||||
|
d.name: d
|
||||||
|
for d in candidate.unmet_dependencies
|
||||||
|
if not _is_fake_dep(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
group_outs = group.get_outputs()
|
||||||
data_dep = None
|
data_dep = None
|
||||||
for o in group_outs:
|
for o in group_outs:
|
||||||
if d := data_deps.get(o.get_name(), None):
|
if d := data_deps.get(o.get_name(), None):
|
||||||
if isinstance(d, WeakDep) and d.is_fake:
|
|
||||||
continue
|
|
||||||
data_dep = d
|
data_dep = d
|
||||||
break
|
break
|
||||||
# 1. If we have data_dep - we can not swap => trying to group
|
# 1. If we have data_dep - we can not swap => trying to group
|
||||||
|
|
|
||||||
|
|
@ -2151,6 +2151,23 @@ class Scheduler:
|
||||||
OrderedSet(V.graph.get_output_names()),
|
OrderedSet(V.graph.get_output_names()),
|
||||||
)
|
)
|
||||||
if config.reorder_for_compute_comm_overlap:
|
if config.reorder_for_compute_comm_overlap:
|
||||||
|
from torch._logging import trace_structured
|
||||||
|
|
||||||
|
trace_structured(
|
||||||
|
"artifact",
|
||||||
|
metadata_fn=lambda: {
|
||||||
|
"name": "scheduler_nodes_before_comm_overlap",
|
||||||
|
"encoding": "string",
|
||||||
|
},
|
||||||
|
payload_fn=lambda: "\n\n".join(
|
||||||
|
[
|
||||||
|
f"snode[{i}]"
|
||||||
|
+ n.debug_str()
|
||||||
|
+ f" buffer_names:{n.get_buffer_names()}"
|
||||||
|
for i, n in enumerate(self.nodes)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
|
self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
|
||||||
self.process_grouped_nodes()
|
self.process_grouped_nodes()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user