mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix missing symbol when printing guards (#165723)
Fixes #165177. When converting guards to sources, if we were unable to get the expected symbol from symbol_to_source, we now try to get it from var_to_sources instead. I was unable to make a simpler repro than what was described in the issue (which relies on llama3 — so it is inappropriate for a unit test). Pull Request resolved: https://github.com/pytorch/pytorch/pull/165723 Approved by: https://github.com/bobrenjc93
This commit is contained in:
parent
ba71e9ca9a
commit
a553ea9ea4
|
|
@ -2204,6 +2204,17 @@ class OutputGraph(OutputGraphCommon):
|
|||
|
||||
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
|
||||
# TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
|
||||
|
||||
# Why create a new FakeTensorMode?
|
||||
#
|
||||
# The reason this needs to be done is because when we do Dynamo tracing, fake
|
||||
# tensors can have their metadata mutated. Thus, the fake tensor we allocated
|
||||
# for any given tensor may no longer be valid for the beginning trace of the
|
||||
# graph. Nor is it convenient to "clone" the input tensors before mutating them,
|
||||
# since you have to preserve aliasing. So we just reconstruct the FakeTensorMode
|
||||
# from scratch when we go to AOTAutograd. But the ShapeEnv must be preserved as
|
||||
# Dynamo made decisions about what is dynamic or not / guards from the user code
|
||||
# that is not in graph.
|
||||
backend_fake_mode = torch._subclasses.FakeTensorMode(
|
||||
shape_env=old_fake_mode.shape_env,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2658,7 +2658,9 @@ class _ShapeGuardPrinter(abc.ABC):
|
|||
Convert a sympy Symbol to its source representation.
|
||||
|
||||
This method looks up the symbol in symbol_to_source mapping and returns
|
||||
the string representation of its first source.
|
||||
the string representation of its first source. If the symbol is not in
|
||||
symbol_to_source (which can happen when symbols appear in guard expressions
|
||||
through simplification or substitution), it falls back to var_to_sources.
|
||||
|
||||
Args:
|
||||
expr: The sympy Symbol to convert
|
||||
|
|
@ -2667,24 +2669,30 @@ class _ShapeGuardPrinter(abc.ABC):
|
|||
String representation of the symbol's source
|
||||
|
||||
Raises:
|
||||
AssertionError: If the symbol is not found in symbol_to_source
|
||||
AssertionError: If the symbol is not found in either mapping
|
||||
"""
|
||||
assert isinstance(expr, sympy.Symbol), str(type(expr))
|
||||
|
||||
def repr_symbol_to_source() -> str:
|
||||
# Try symbol_to_source first, fall back to var_to_sources if not found
|
||||
if source := self.symbol_to_source.get(expr):
|
||||
return self.print_source(source[0])
|
||||
elif source := self.var_to_sources.get(expr):
|
||||
return self.print_source(source[0])
|
||||
else:
|
||||
|
||||
def repr_sources(src: Mapping[sympy.Symbol, list[Source]]) -> str:
|
||||
return repr(
|
||||
{
|
||||
symbol: [s.name() for s in sources]
|
||||
for symbol, sources in self.symbol_to_source.items()
|
||||
for symbol, sources in src.items()
|
||||
}
|
||||
)
|
||||
|
||||
assert self.symbol_to_source.get(expr), (
|
||||
f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) "
|
||||
f"not in {repr_symbol_to_source()}. If this assert is failing, it could be "
|
||||
"due to the issue described in https://github.com/pytorch/pytorch/pull/90665"
|
||||
raise RuntimeError(
|
||||
f"{expr} not in {repr_sources(self.symbol_to_source)} or "
|
||||
f"{repr_sources(self.var_to_sources)}. This could be due to "
|
||||
"the issue described in https://github.com/pytorch/pytorch/pull/90665"
|
||||
)
|
||||
return self.print_source(self.symbol_to_source[expr][0])
|
||||
|
||||
@abc.abstractmethod
|
||||
def print_source(self, source: Source) -> str:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user