mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add shape_env guards to tracing context (#90876)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90876 Approved by: https://github.com/Chillee, https://github.com/ezyang
This commit is contained in:
parent
a01c1ee594
commit
53e71fad8f
|
|
@ -196,13 +196,24 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
|
|||
super(OutputGraph, self).__init__()
|
||||
self.graph = torch.fx.Graph()
|
||||
self.graphargs: List[GraphArg] = []
|
||||
self.tracing_context: TracingContext = TracingContext()
|
||||
fake_mode = torch._subclasses.FakeTensorMode(
|
||||
throw_on_data_dependent_ops=True,
|
||||
shape_env=ShapeEnv() if config.dynamic_shapes else None,
|
||||
)
|
||||
self.tracing_context: TracingContext = TracingContext(fake_mode)
|
||||
# tracked_fakes says where any tensor that was wrapped to fake came
|
||||
# from. It is similar to GraphArg, in that all GraphArgs will get
|
||||
# will get added to TrackedFakes, but TrackedFakes also contains
|
||||
# GraphArgs that got pruned, and things like Tensor attributes which
|
||||
# aren't explicit graph inputs. Used by shape guard
|
||||
self.tracked_fakes: List[TrackedFake] = []
|
||||
# Although we prune unused graphargs before sending graphs to
|
||||
# compilers, we may have legitimately triggered shape guards
|
||||
# on "unused" inputs that we must keep track of. So after
|
||||
# remove_unused_graphargs is called, orig_graphargs and
|
||||
# graphargs no longer alias; orig_graphargs is the original
|
||||
# graphargs, and graphargs is the pruned list. Guard creation
|
||||
# should use original graphargs.
|
||||
self.orig_graphargs: List[GraphArg] = self.graphargs
|
||||
self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
|
||||
self.side_effects = SideEffects()
|
||||
|
|
@ -228,7 +239,6 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
|
|||
self.unspec_variable_map: Dict[
|
||||
str, Union[UnspecializedNumpyVariable, UnspecializedPythonVariable]
|
||||
] = {}
|
||||
self.shape_env = ShapeEnv() if config.dynamic_shapes else None
|
||||
self.intermediary_symbols: Dict[sympy.Expr, None] = {}
|
||||
|
||||
# Enables creating unique node names by tracking
|
||||
|
|
@ -245,6 +255,10 @@ class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
|
|||
def fake_mode(self):
|
||||
return self.root_tx.fake_mode
|
||||
|
||||
@property
|
||||
def shape_env(self):
|
||||
return self.tracing_context.fake_mode.shape_env
|
||||
|
||||
@property
|
||||
def guards(self) -> Set[Guard]:
|
||||
return self.tracing_context.guards_context.dynamo_guards
|
||||
|
|
|
|||
|
|
@ -1577,10 +1577,7 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
|
|||
# Flag to indicate whether tracing is used for export.
|
||||
self.export = export
|
||||
|
||||
self._fake_mode = torch._subclasses.FakeTensorMode(
|
||||
throw_on_data_dependent_ops=True,
|
||||
shape_env=output.shape_env,
|
||||
)
|
||||
self._fake_mode = output.tracing_context.fake_mode
|
||||
|
||||
self.checkpoint = None
|
||||
self.random_calls = []
|
||||
|
|
|
|||
|
|
@ -285,8 +285,9 @@ class TracingContext:
|
|||
def get() -> Optional["TracingContext"]:
|
||||
return _CURRENT_TRACING_CONTEXT
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, fake_mode):
|
||||
self.guards_context = GuardsContext()
|
||||
self.fake_mode = fake_mode
|
||||
|
||||
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -753,7 +753,7 @@ class ShapeEnv(object):
|
|||
try:
|
||||
exprs.append(ShapeGuardPrinter(symbol_to_source).doprint(g))
|
||||
except Exception:
|
||||
log.warning(f"Failing guard allocated at:\n{tb}")
|
||||
log.warning(f"Failing guard allocated at: \n{tb}")
|
||||
raise
|
||||
|
||||
# 3. Every symbol must not be equal to 0/1
|
||||
|
|
@ -821,7 +821,7 @@ class ShapeEnv(object):
|
|||
return bindings
|
||||
|
||||
def get_nontrivial_guards(self):
|
||||
return [self.simplify(guard) for guard, _ in self.guards if self._maybe_evaluate_static(guard) is None]
|
||||
return [self.simplify(guard.expr) for guard in self.guards if self._maybe_evaluate_static(guard.expr) is None]
|
||||
|
||||
def format_guards(self, verbose=False):
|
||||
def format_tb(tb):
|
||||
|
|
@ -829,7 +829,7 @@ class ShapeEnv(object):
|
|||
return ""
|
||||
return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}"
|
||||
|
||||
return '\n'.join(f" - {guard}{format_tb(tb)}" for guard, tb in self.guards)
|
||||
return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards)
|
||||
|
||||
def get_shape_groups(self):
|
||||
shape_groups = collections.defaultdict(list)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user