Fix missing symbol when printing guards (#165723)

Fixes #165177

When converting guards to sources if we were unable to get the expected symbol from symbol_to_source then try to get it from var_to_sources.

I was unable to make a simpler repro than what was described in the issue (which relies on llama3 - so inappropriate for a unit test).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165723
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein 2025-10-29 14:31:28 -07:00 committed by PyTorch MergeBot
parent ba71e9ca9a
commit a553ea9ea4
2 changed files with 34 additions and 15 deletions

View File

@ -2204,6 +2204,17 @@ class OutputGraph(OutputGraphCommon):
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
# TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
# Why create a new FakeTensorMode?
#
# The reason this needs to be done is because when we do Dynamo tracing, fake
# tensors can have their metadata mutated. Thus, the fake tensor we allocated
# for any given tensor may no longer be valid for the beginning trace of the
# graph. Nor is it convenient to "clone" the input tensors before mutating them,
# since you have to preserve aliasing. So we just reconstruct the FakeTensorMode
# from scratch when we go to AOTAutograd. But the ShapeEnv must be preserved as
# Dynamo made decisions about what is dynamic or not / guards from the user code
# that is not in graph.
backend_fake_mode = torch._subclasses.FakeTensorMode(
shape_env=old_fake_mode.shape_env,
)

View File

@ -2658,7 +2658,9 @@ class _ShapeGuardPrinter(abc.ABC):
Convert a sympy Symbol to its source representation.
This method looks up the symbol in symbol_to_source mapping and returns
the string representation of its first source.
the string representation of its first source. If the symbol is not in
symbol_to_source (which can happen when symbols appear in guard expressions
through simplification or substitution), it falls back to var_to_sources.
Args:
expr: The sympy Symbol to convert
@ -2667,24 +2669,30 @@ class _ShapeGuardPrinter(abc.ABC):
String representation of the symbol's source
Raises:
AssertionError: If the symbol is not found in symbol_to_source
AssertionError: If the symbol is not found in either mapping
"""
assert isinstance(expr, sympy.Symbol), str(type(expr))
def repr_symbol_to_source() -> str:
return repr(
{
symbol: [s.name() for s in sources]
for symbol, sources in self.symbol_to_source.items()
}
)
# Try symbol_to_source first, fall back to var_to_sources if not found
if source := self.symbol_to_source.get(expr):
return self.print_source(source[0])
elif source := self.var_to_sources.get(expr):
return self.print_source(source[0])
else:
assert self.symbol_to_source.get(expr), (
f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) "
f"not in {repr_symbol_to_source()}. If this assert is failing, it could be "
"due to the issue described in https://github.com/pytorch/pytorch/pull/90665"
)
return self.print_source(self.symbol_to_source[expr][0])
def repr_sources(src: Mapping[sympy.Symbol, list[Source]]) -> str:
return repr(
{
symbol: [s.name() for s in sources]
for symbol, sources in src.items()
}
)
raise RuntimeError(
f"{expr} not in {repr_sources(self.symbol_to_source)} or "
f"{repr_sources(self.var_to_sources)}. This could be due to "
"the issue described in https://github.com/pytorch/pytorch/pull/90665"
)
@abc.abstractmethod
def print_source(self, source: Source) -> str: