mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[Dynamo] Fix typing in graph_deduplication.py (#152572)"
This reverts commit 15166be691.
Reverted https://github.com/pytorch/pytorch/pull/152572 on behalf of https://github.com/jeanschmidt due to Breaking internal signal citadel-fbcode-test-mode-opt-for-pt2_stack_for_internal-linux-0 please see diff [D74531503](https://www.internalfb.com/diff/D74531503) for more details ([comment](https://github.com/pytorch/pytorch/pull/152410#issuecomment-2871168679))
This commit is contained in:
parent
aa7fe6af41
commit
0071fdab9e
|
|
@ -124,7 +124,7 @@ def _replace_region_with_subgraph(
|
||||||
external_node_usages: Iterable[OrderedSet[UsageIndex]],
|
external_node_usages: Iterable[OrderedSet[UsageIndex]],
|
||||||
inds_with_external_users: list[int],
|
inds_with_external_users: list[int],
|
||||||
subgraph_name: str,
|
subgraph_name: str,
|
||||||
node_to_additional_deps: dict[Node, OrderedSet[Node]],
|
node_to_additional_deps: dict[torch.fx.Node, OrderedSet[torch.fx.Node]],
|
||||||
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
|
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
|
||||||
) -> None:
|
) -> None:
|
||||||
sub_args = []
|
sub_args = []
|
||||||
|
|
@ -264,7 +264,7 @@ def _create_subgraph(
|
||||||
|
|
||||||
def _stable_topological_sort(
|
def _stable_topological_sort(
|
||||||
graph: torch.fx.Graph,
|
graph: torch.fx.Graph,
|
||||||
node_to_additional_deps: dict[Node, OrderedSet[Node]],
|
node_to_additional_deps: dict[torch.fx.Node, OrderedSet[torch.fx.Node]],
|
||||||
) -> None:
|
) -> None:
|
||||||
# Nodes are in exactly one of these four collections:
|
# Nodes are in exactly one of these four collections:
|
||||||
|
|
||||||
|
|
@ -273,14 +273,14 @@ def _stable_topological_sort(
|
||||||
|
|
||||||
# - Nodes in `ready` have been processed and are already in the correct
|
# - Nodes in `ready` have been processed and are already in the correct
|
||||||
# order.
|
# order.
|
||||||
ready = OrderedSet[Node]()
|
ready = OrderedSet[torch.fx.Node]()
|
||||||
|
|
||||||
# - `waiting` is a mapping from a dependency to nodes which depend on that
|
# - `waiting` is a mapping from a dependency to nodes which depend on that
|
||||||
# dependency.
|
# dependency.
|
||||||
waiting = defaultdict(list)
|
waiting = defaultdict(list)
|
||||||
|
|
||||||
# - `outputs` are always at the end of the graph
|
# - `outputs` are always at the end of the graph
|
||||||
outputs = OrderedSet[Node]()
|
outputs = OrderedSet[torch.fx.Node]()
|
||||||
|
|
||||||
# The cursor indicates the last processed node so we can add new nodes
|
# The cursor indicates the last processed node so we can add new nodes
|
||||||
# after it.
|
# after it.
|
||||||
|
|
@ -386,7 +386,7 @@ def _add_mutation_dependencies(
|
||||||
def _has_aliasing(
|
def _has_aliasing(
|
||||||
region: Region, inputs: list[Node], inds_with_external_users: list[int]
|
region: Region, inputs: list[Node], inds_with_external_users: list[int]
|
||||||
) -> bool:
|
) -> bool:
|
||||||
input_storages: dict[StorageWeakRef, Node] = dict()
|
input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
|
||||||
|
|
||||||
for node in inputs:
|
for node in inputs:
|
||||||
example_value = node.meta["example_value"]
|
example_value = node.meta["example_value"]
|
||||||
|
|
@ -403,7 +403,7 @@ def _has_aliasing(
|
||||||
return True
|
return True
|
||||||
input_storages[storage] = node
|
input_storages[storage] = node
|
||||||
|
|
||||||
output_storages: dict[StorageWeakRef, Node] = dict()
|
output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
|
||||||
for i in inds_with_external_users:
|
for i in inds_with_external_users:
|
||||||
out_node = region[i]
|
out_node = region[i]
|
||||||
if out_node:
|
if out_node:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user