[dynamo] Ensure placeholder name is not an intermediate node name (#149712)

Fixes https://fb.workplace.com/groups/1075192433118967/permalink/1615671879071017/

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149712
Approved by: https://github.com/zou3519
This commit is contained in:
Animesh Jain 2025-03-21 11:23:44 -07:00 committed by PyTorch MergeBot
parent 7f836b747f
commit d320af0663

View File

@ -63,6 +63,7 @@ from torch.fx.experimental.symbolic_shapes import (
ShapeEnv,
)
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch.utils._ordered_set import OrderedSet
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from . import config, exc, graph_break_hints, logging as torchdynamo_logging, variables
@ -2020,6 +2021,9 @@ class SubgraphTracer(fx.Tracer):
(self.graph._target_to_str(source_target), source_target)
]
# This is used to create a unique name for the placeholder
self._used_names: OrderedSet[str] = OrderedSet()
# preserve original meta if it is available
def _maybe_preserve_original_meta(self, tx, node):
if (
@ -2239,6 +2243,7 @@ class SubgraphTracer(fx.Tracer):
node = super().create_node(op, target, args, kwargs, name, type_expr)
node.meta["creation_timestamp"] = self.output_graph.timestamp
self._used_names.add(node.name)
return node
# Note: we did not override erase_node since
@ -2296,7 +2301,10 @@ class SubgraphTracer(fx.Tracer):
TracingContext.extract_stack()
)
name = get_unique_name_wrt(name, self.input_name_to_proxy)
# _used_names contains the names of all the nodes in the graph,
# including intermediates. This ensures that we do not have a name
# collision.
name = get_unique_name_wrt(name, self._used_names)
if self.input_name_to_proxy:
prev_name = next(reversed(self.input_name_to_proxy))
node = self.input_name_to_proxy[prev_name].node
@ -2316,6 +2324,11 @@ class SubgraphTracer(fx.Tracer):
else:
self.input_name_to_proxy[name] = proxy
# For placeholder nodes, `name` is passed as a str to the target,
# and then torch.fx decides the node.name. So, record the `target`
# name as well in the _used_names to prevent any collision.
self._used_names.add(name)
# NOTE: [Auto lift basic free symbols when create_graph_input]
# Whenever we call create_graph_input, we try to also lift the basic symbols in example values
# as graph input.