Revert "[Dynamo] Fix typing in graph_deduplication.py (#152572)"

This reverts commit 15166be691.

Reverted https://github.com/pytorch/pytorch/pull/152572 on behalf of https://github.com/jeanschmidt due to Breaking internal signal citadel-fbcode-test-mode-opt-for-pt2_stack_for_internal-linux-0 please see diff [D74531503](https://www.internalfb.com/diff/D74531503) for more details ([comment](https://github.com/pytorch/pytorch/pull/152410#issuecomment-2871168679))
This commit is contained in:
PyTorch MergeBot 2025-05-12 07:15:09 +00:00
parent aa7fe6af41
commit 0071fdab9e

View File

@ -124,7 +124,7 @@ def _replace_region_with_subgraph(
external_node_usages: Iterable[OrderedSet[UsageIndex]],
inds_with_external_users: list[int],
subgraph_name: str,
node_to_additional_deps: dict[Node, OrderedSet[Node]],
node_to_additional_deps: dict[torch.fx.Node, OrderedSet[torch.fx.Node]],
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
) -> None:
sub_args = []
@ -264,7 +264,7 @@ def _create_subgraph(
def _stable_topological_sort(
graph: torch.fx.Graph,
node_to_additional_deps: dict[Node, OrderedSet[Node]],
node_to_additional_deps: dict[torch.fx.Node, OrderedSet[torch.fx.Node]],
) -> None:
# Nodes are in exactly one of these four collections:
@ -273,14 +273,14 @@ def _stable_topological_sort(
# - Nodes in `ready` have been processed and are already in the correct
# order.
ready = OrderedSet[Node]()
ready = OrderedSet[torch.fx.Node]()
# - `waiting` is a mapping from a dependency to nodes which depend on that
# dependency.
waiting = defaultdict(list)
# - `outputs` are always at the end of the graph
outputs = OrderedSet[Node]()
outputs = OrderedSet[torch.fx.Node]()
# The cursor indicates the last processed node so we can add new nodes
# after it.
@ -386,7 +386,7 @@ def _add_mutation_dependencies(
def _has_aliasing(
region: Region, inputs: list[Node], inds_with_external_users: list[int]
) -> bool:
input_storages: dict[StorageWeakRef, Node] = dict()
input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
for node in inputs:
example_value = node.meta["example_value"]
@ -403,7 +403,7 @@ def _has_aliasing(
return True
input_storages[storage] = node
output_storages: dict[StorageWeakRef, Node] = dict()
output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
for i in inds_with_external_users:
out_node = region[i]
if out_node: