Add shape_env guards to tracing context (#90876)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90876
Approved by: https://github.com/Chillee, https://github.com/ezyang
Michael Voznesensky 2022-12-16 09:05:02 +00:00 committed by PyTorch MergeBot
parent a01c1ee594
commit 53e71fad8f
4 changed files with 22 additions and 10 deletions
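At a glance: instead of OutputGraph owning a bare ShapeEnv and every InstructionTranslator building its own FakeTensorMode, this commit creates one FakeTensorMode (holding the one ShapeEnv) and hangs it off the TracingContext, so shape guards accumulate in a single place reachable from anywhere in the trace. Below is a minimal sketch of the resulting ownership chain; ShapeEnv and FakeTensorMode are the real classes, while ToyTracingContext is a stand-in for torch._guards.TracingContext:

from torch._subclasses import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()                            # single store for symbols and guards
fake_mode = FakeTensorMode(shape_env=shape_env)   # the mode owns the env

class ToyTracingContext:                          # stand-in for torch._guards.TracingContext
    def __init__(self, fake_mode):
        self.fake_mode = fake_mode

ctx = ToyTracingContext(fake_mode)
assert ctx.fake_mode.shape_env is shape_env       # one ShapeEnv, reachable from the context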

torch/_dynamo/output_graph.py

@@ -196,13 +196,24 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
         super(OutputGraph, self).__init__()
         self.graph = torch.fx.Graph()
         self.graphargs: List[GraphArg] = []
-        self.tracing_context: TracingContext = TracingContext()
+        fake_mode = torch._subclasses.FakeTensorMode(
+            throw_on_data_dependent_ops=True,
+            shape_env=ShapeEnv() if config.dynamic_shapes else None,
+        )
+        self.tracing_context: TracingContext = TracingContext(fake_mode)
+        # tracked_fakes says where any tensor that was wrapped to fake came
+        # from. It is similar to GraphArg, in that all GraphArgs will get
+        # added to TrackedFakes, but TrackedFakes also contains
+        # GraphArgs that got pruned, and things like Tensor attributes which
+        # aren't explicit graph inputs. Used by shape guards.
+        self.tracked_fakes: List[TrackedFake] = []
         # Although we prune unused graphargs before sending graphs to
         # compilers, we may have legitimately triggered shape guards
         # on "unused" inputs that we must keep track of. So after
         # remove_unused_graphargs is called, orig_graphargs and
         # graphargs no longer alias; orig_graphargs is the original
         # graphargs, and graphargs is the pruned list. Guard creation
         # should use original graphargs.
         self.orig_graphargs: List[GraphArg] = self.graphargs
         self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
         self.side_effects = SideEffects()
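The pruning comment above deserves a concrete case. In the hypothetical function below (not from this commit), y never contributes a value to the compiled graph, yet tracing specialized on y.shape[0]; that guard must survive even though y is pruned from graphargs, which is exactly why orig_graphargs is preserved:

def f(x, y):
    # The branch records a shape guard on y, but y's value is unused,
    # so remove_unused_graphargs would prune it from the graph inputs.
    if y.shape[0] > 2:
        return x + 1
    return x - 1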
@@ -228,7 +239,6 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
         self.unspec_variable_map: Dict[
             str, Union[UnspecializedNumpyVariable, UnspecializedPythonVariable]
         ] = {}
-        self.shape_env = ShapeEnv() if config.dynamic_shapes else None
         self.intermediary_symbols: Dict[sympy.Expr, None] = {}
         # Enables creating unique node names by tracking
@@ -245,6 +255,10 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
     def fake_mode(self):
         return self.root_tx.fake_mode

+    @property
+    def shape_env(self):
+        return self.tracing_context.fake_mode.shape_env
+
     @property
     def guards(self) -> Set[Guard]:
         return self.tracing_context.guards_context.dynamo_guards
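With the new property, OutputGraph no longer owns a ShapeEnv of its own; self.shape_env is pure delegation, so every reader sees the environment held by the tracing context. A sketch of the invariant, where output_graph stands for any post-commit OutputGraph instance (hypothetical name):

def check_single_env(output_graph):
    # output_graph: hypothetical post-commit OutputGraph instance.
    env_a = output_graph.shape_env                              # via the new property
    env_b = output_graph.tracing_context.fake_mode.shape_env    # via the context directly
    assert env_a is env_b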

torch/_dynamo/symbolic_convert.py

@@ -1577,10 +1577,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState]):
         # Flag to indicate whether tracing is used for export.
         self.export = export

-        self._fake_mode = torch._subclasses.FakeTensorMode(
-            throw_on_data_dependent_ops=True,
-            shape_env=output.shape_env,
-        )
+        self._fake_mode = output.tracing_context.fake_mode

         self.checkpoint = None
         self.random_calls = []
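This hunk is the payoff: the translator stops constructing a second FakeTensorMode and reuses the one on the tracing context. Two modes built independently would each get their own ShapeEnv, so guards minted while tracing under one would be invisible to the other. A sketch of the split-state hazard being removed (illustrative, using the real classes):

from torch._subclasses import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode_a = FakeTensorMode(shape_env=ShapeEnv())
mode_b = FakeTensorMode(shape_env=ShapeEnv())
assert mode_a.shape_env is not mode_b.shape_env   # two disconnected guard stores
# After this commit OutputGraph and InstructionTranslatorBase both read
# the single mode stored on the TracingContext, so this split cannot occur.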

torch/_guards.py

@@ -285,8 +285,9 @@ class TracingContext:
     def get() -> Optional["TracingContext"]:
         return _CURRENT_TRACING_CONTEXT

-    def __init__(self):
+    def __init__(self, fake_mode):
         self.guards_context = GuardsContext()
+        self.fake_mode = fake_mode

     """

torch/fx/experimental/symbolic_shapes.py

@@ -753,7 +753,7 @@ class ShapeEnv(object):
            try:
                exprs.append(ShapeGuardPrinter(symbol_to_source).doprint(g))
            except Exception:
-                log.warning(f"Failing guard allocated at:\n{tb}")
+                log.warning(f"Failing guard allocated at: \n{tb}")
                raise

        # 3. Every symbol must not be equal to 0/1
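The trailing context comment ("# 3") refers to ShapeEnv's 0/1 specialization: sizes 0 and 1 compile as constants, so each surviving symbol carries an implicit not-0/not-1 guard. One plausible sympy encoding of such a guard (the exact expression ShapeEnv builds may differ):

import sympy

s0 = sympy.Symbol("s0", integer=True)   # a symbolic size
guard = sympy.And(sympy.Ne(s0, 0), sympy.Ne(s0, 1))
print(guard)   # Ne(s0, 0) & Ne(s0, 1)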
@@ -821,7 +821,7 @@ class ShapeEnv(object):
        return bindings

    def get_nontrivial_guards(self):
-        return [self.simplify(guard) for guard, _ in self.guards if self._maybe_evaluate_static(guard) is None]
+        return [self.simplify(guard.expr) for guard in self.guards if self._maybe_evaluate_static(guard.expr) is None]

    def format_guards(self, verbose=False):
        def format_tb(tb):
@@ -829,7 +829,7 @@ class ShapeEnv(object):
                return ""
            return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}"

-        return '\n'.join(f" - {guard}{format_tb(tb)}" for guard, tb in self.guards)
+        return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards)

    def get_shape_groups(self):
        shape_groups = collections.defaultdict(list)
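The last two hunks are both fallout from one data-shape change: ShapeEnv.guards now holds records with .expr and .stack fields rather than bare (expr, tb) tuples. A minimal sketch of such a record; the real definition in symbolic_shapes.py may differ in detail:

from typing import NamedTuple
import sympy

class ShapeGuard(NamedTuple):
    expr: sympy.Expr   # symbolic condition that must hold at runtime
    stack: str         # formatted traceback of where the guard was allocated

Callers such as get_nontrivial_guards and format_guards then iterate self.guards directly and read guard.expr and guard.stack, as the replacement lines above do, instead of unpacking a tuple.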