mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[Dynamo] Fix typing in graph_deduplication.py (#152572)"
This reverts commit 15166be691.
Reverted https://github.com/pytorch/pytorch/pull/152572 on behalf of https://github.com/jeanschmidt due to Breaking internal signal citadel-fbcode-test-mode-opt-for-pt2_stack_for_internal-linux-0 please see diff [D74531503](https://www.internalfb.com/diff/D74531503) for more details ([comment](https://github.com/pytorch/pytorch/pull/152410#issuecomment-2871168679))
This commit is contained in:
parent
aa7fe6af41
commit
0071fdab9e
|
|
@ -124,7 +124,7 @@ def _replace_region_with_subgraph(
|
|||
external_node_usages: Iterable[OrderedSet[UsageIndex]],
|
||||
inds_with_external_users: list[int],
|
||||
subgraph_name: str,
|
||||
node_to_additional_deps: dict[Node, OrderedSet[Node]],
|
||||
node_to_additional_deps: dict[torch.fx.Node, OrderedSet[torch.fx.Node]],
|
||||
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
|
||||
) -> None:
|
||||
sub_args = []
|
||||
|
|
@ -264,7 +264,7 @@ def _create_subgraph(
|
|||
|
||||
def _stable_topological_sort(
|
||||
graph: torch.fx.Graph,
|
||||
node_to_additional_deps: dict[Node, OrderedSet[Node]],
|
||||
node_to_additional_deps: dict[torch.fx.Node, OrderedSet[torch.fx.Node]],
|
||||
) -> None:
|
||||
# Nodes are in exactly one of these four collections:
|
||||
|
||||
|
|
@ -273,14 +273,14 @@ def _stable_topological_sort(
|
|||
|
||||
# - Nodes in `ready` have been processed and are already in the correct
|
||||
# order.
|
||||
ready = OrderedSet[Node]()
|
||||
ready = OrderedSet[torch.fx.Node]()
|
||||
|
||||
# - `waiting` is a mapping from a dependency to nodes which depend on that
|
||||
# dependency.
|
||||
waiting = defaultdict(list)
|
||||
|
||||
# - `outputs` are always at the end of the graph
|
||||
outputs = OrderedSet[Node]()
|
||||
outputs = OrderedSet[torch.fx.Node]()
|
||||
|
||||
# The cursor indicates the last processed node so we can add new nodes
|
||||
# after it.
|
||||
|
|
@ -386,7 +386,7 @@ def _add_mutation_dependencies(
|
|||
def _has_aliasing(
|
||||
region: Region, inputs: list[Node], inds_with_external_users: list[int]
|
||||
) -> bool:
|
||||
input_storages: dict[StorageWeakRef, Node] = dict()
|
||||
input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
|
||||
|
||||
for node in inputs:
|
||||
example_value = node.meta["example_value"]
|
||||
|
|
@ -403,7 +403,7 @@ def _has_aliasing(
|
|||
return True
|
||||
input_storages[storage] = node
|
||||
|
||||
output_storages: dict[StorageWeakRef, Node] = dict()
|
||||
output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
|
||||
for i in inds_with_external_users:
|
||||
out_node = region[i]
|
||||
if out_node:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user