mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Ensure placeholder name is not an intermediate node name (#149712)
Fixes https://fb.workplace.com/groups/1075192433118967/permalink/1615671879071017/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/149712 Approved by: https://github.com/zou3519
This commit is contained in:
parent
7f836b747f
commit
d320af0663
|
|
@ -63,6 +63,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
|||
ShapeEnv,
|
||||
)
|
||||
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
|
||||
from . import config, exc, graph_break_hints, logging as torchdynamo_logging, variables
|
||||
|
|
@ -2020,6 +2021,9 @@ class SubgraphTracer(fx.Tracer):
|
|||
(self.graph._target_to_str(source_target), source_target)
|
||||
]
|
||||
|
||||
# This is used to create a unique name for the placeholder
|
||||
self._used_names: OrderedSet[str] = OrderedSet()
|
||||
|
||||
# preserve original meta if it is available
|
||||
def _maybe_preserve_original_meta(self, tx, node):
|
||||
if (
|
||||
|
|
@ -2239,6 +2243,7 @@ class SubgraphTracer(fx.Tracer):
|
|||
|
||||
node = super().create_node(op, target, args, kwargs, name, type_expr)
|
||||
node.meta["creation_timestamp"] = self.output_graph.timestamp
|
||||
self._used_names.add(node.name)
|
||||
return node
|
||||
|
||||
# Note: we did not override erase_node since
|
||||
|
|
@ -2296,7 +2301,10 @@ class SubgraphTracer(fx.Tracer):
|
|||
TracingContext.extract_stack()
|
||||
)
|
||||
|
||||
name = get_unique_name_wrt(name, self.input_name_to_proxy)
|
||||
# _used_names contains the names of all the nodes in the graph,
|
||||
# including intermediates. This ensures that we do not have a name
|
||||
# collision.
|
||||
name = get_unique_name_wrt(name, self._used_names)
|
||||
if self.input_name_to_proxy:
|
||||
prev_name = next(reversed(self.input_name_to_proxy))
|
||||
node = self.input_name_to_proxy[prev_name].node
|
||||
|
|
@ -2316,6 +2324,11 @@ class SubgraphTracer(fx.Tracer):
|
|||
else:
|
||||
self.input_name_to_proxy[name] = proxy
|
||||
|
||||
# For placeholder nodes, `name` is passed as a str to the target,
|
||||
# and then torch.fx decides the node.name. So, record the `target`
|
||||
# name as well in the _used_names to prevent any collision.
|
||||
self._used_names.add(name)
|
||||
|
||||
# NOTE: [Auto lift basic free symbols when create_graph_input]
|
||||
# Whenever we call create_graph_input, we try to also lift the basic symbols in example values
|
||||
# as graph input.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user